mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-10-29 18:40:37 +08:00 
			
		
		
		
	Compare commits
	
		
			96 Commits
		
	
	
		
			fixup-arg-
			...
			upload-rev
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| b2e8703329 | |||
| 060c326d89 | |||
| 34a1932751 | |||
| e39eac09c1 | |||
| b0c431fee4 | |||
| 9a188eadbe | |||
| 457c7c1b8d | |||
| fb8cd99a2c | |||
| dfee307d54 | |||
| 93e5765611 | |||
| bf488208be | |||
| 2a14472e4c | |||
| 055a953552 | |||
| 692d5ad458 | |||
| 2139df57f4 | |||
| 8f9a77bb6a | |||
| 6c00194680 | |||
| d6b51eefb7 | |||
| d383fdd4b4 | |||
| 07e5e8481a | |||
| 88f55d4728 | |||
| e801ebf332 | |||
| 0ae07f05fc | |||
| 7611021100 | |||
| 767e7ccf13 | |||
| 1caa4c1393 | |||
| da701bf58a | |||
| 703664ed31 | |||
| a8a6564fa7 | |||
| c89e0fa9b9 | |||
| 176a601178 | |||
| cfa0c76ddc | |||
| bcc29915f9 | |||
| 6fbff7a9cb | |||
| f7490bd0a9 | |||
| 8069e3bf0c | |||
| c540d1e1d6 | |||
| 967ac581b8 | |||
| 81088d44e8 | |||
| 4a04c005e3 | |||
| 6d3c6daf20 | |||
| 071900fd69 | |||
| 2d2c6b14e0 | |||
| 03edc573b1 | |||
| c841a6c90d | |||
| c7a343f195 | |||
| 8d838f947d | |||
| b87e6fadbe | |||
| fc935d9874 | |||
| 3622e1f8dd | |||
| a7f3b2e8ed | |||
| a6ab5d83ba | |||
| 4f9f1abfb9 | |||
| f94b7780a6 | |||
| bd28883775 | |||
| 498429e322 | |||
| 09c991af4b | |||
| bcf8df5875 | |||
| 239afff6f5 | |||
| c5ec6b900a | |||
| 3a635eaeea | |||
| 32ec496c5a | |||
| 848c6db87b | |||
| fabb8c52d1 | |||
| d66260dd83 | |||
| daac8078fc | |||
| fcb9a80ce6 | |||
| c25bb32e6e | |||
| 2036892762 | |||
| 0f0de049cf | |||
| 59597df03e | |||
| 5e938ede40 | |||
| cf530c283a | |||
| 437f910336 | |||
| 6f1a6067c8 | |||
| 1d14abcef0 | |||
| 6fd2112e22 | |||
| 70f56ff856 | |||
| 7178b0b86c | |||
| 0bbf90a564 | |||
| 27d6ffcb80 | |||
| f7bd21438b | |||
| 6174febb4b | |||
| ff55bc201b | |||
| 3808108d62 | |||
| c4a16ef462 | |||
| 9762794dd2 | |||
| b7d6867c52 | |||
| fbcd0f2ebd | |||
| 5af46eca94 | |||
| 747dd66876 | |||
| 920590a592 | |||
| 5208ac4be5 | |||
| 22eaba2826 | |||
| 9521ba79a0 | |||
| 9861a5bdef | 
							
								
								
									
										17
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,17 @@ | ||||
| name: Build documentation | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     branches: | ||||
|       - main | ||||
|       - doc-builder* | ||||
|       - v*-release | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main | ||||
|     with: | ||||
|       commit_sha: ${{ github.sha }} | ||||
|       package: kernels | ||||
|     secrets: | ||||
|       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} | ||||
							
								
								
									
										15
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | ||||
| name: Build PR Documentation | ||||
|  | ||||
| on: pull_request | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main | ||||
|     with: | ||||
|       commit_sha: ${{ github.event.pull_request.head.sha }} | ||||
|       pr_number: ${{ github.event.number }} | ||||
|       package: kernels | ||||
							
								
								
									
										21
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							| @ -8,3 +8,24 @@ jobs: | ||||
|       - uses: actions/checkout@v4 | ||||
|       - name: Run ruff | ||||
|         uses: astral-sh/ruff-action@v3 | ||||
|  | ||||
|   black: | ||||
|     name: Run black check | ||||
|     runs-on: ubuntu-latest | ||||
|     env: | ||||
|       UV_PYTHON_PREFERENCE: only-managed | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|  | ||||
|       - name: Install uv and set the python version | ||||
|         uses: astral-sh/setup-uv@v5 | ||||
|         with: | ||||
|           python-version: 3.12 | ||||
|  | ||||
|       - name: Install black | ||||
|         run: uv pip install black | ||||
|  | ||||
|       - name: Check formatting | ||||
|         run: | | ||||
|           uv run black --check src | ||||
|           uv run black --check tests | ||||
|  | ||||
							
								
								
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,120 @@ | ||||
| name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI | ||||
|  | ||||
| on: push | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     name: Build distribution 📦 | ||||
|     runs-on: ubuntu-latest | ||||
|  | ||||
|     steps: | ||||
|       - uses: actions/checkout@v4 | ||||
|         with: | ||||
|           persist-credentials: false | ||||
|       - name: Set up Python | ||||
|         uses: actions/setup-python@v5 | ||||
|         with: | ||||
|           python-version: "3.9" | ||||
|       - name: Install pypa/build | ||||
|         run: >- | ||||
|           python3 -m | ||||
|           pip install | ||||
|           build | ||||
|           --user | ||||
|       - name: Build a binary wheel and a source tarball | ||||
|         run: python3 -m build | ||||
|       - name: Store the distribution packages | ||||
|         uses: actions/upload-artifact@v4 | ||||
|         with: | ||||
|           name: python-package-distributions | ||||
|           path: dist/ | ||||
|  | ||||
|   publish-to-pypi: | ||||
|     name: >- | ||||
|       Publish Python 🐍 distribution 📦 to PyPI | ||||
|     if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes | ||||
|     needs: | ||||
|       - build | ||||
|     runs-on: ubuntu-latest | ||||
|     environment: | ||||
|       name: pypi | ||||
|       url: https://pypi.org/p/kernels | ||||
|     permissions: | ||||
|       id-token: write # IMPORTANT: mandatory for trusted publishing | ||||
|  | ||||
|     steps: | ||||
|       - name: Download all the dists | ||||
|         uses: actions/download-artifact@v4 | ||||
|         with: | ||||
|           name: python-package-distributions | ||||
|           path: dist/ | ||||
|       - name: Publish distribution 📦 to PyPI | ||||
|         uses: pypa/gh-action-pypi-publish@release/v1 | ||||
|  | ||||
|   github-release: | ||||
|     name: >- | ||||
|       Sign the Python 🐍 distribution 📦 with Sigstore | ||||
|       and upload them to GitHub Release | ||||
|     needs: | ||||
|       - publish-to-pypi | ||||
|     runs-on: ubuntu-latest | ||||
|  | ||||
|     permissions: | ||||
|       contents: write # IMPORTANT: mandatory for making GitHub Releases | ||||
|       id-token: write # IMPORTANT: mandatory for sigstore | ||||
|  | ||||
|     steps: | ||||
|       - name: Download all the dists | ||||
|         uses: actions/download-artifact@v4 | ||||
|         with: | ||||
|           name: python-package-distributions | ||||
|           path: dist/ | ||||
|       - name: Sign the dists with Sigstore | ||||
|         uses: sigstore/gh-action-sigstore-python@v3.0.0 | ||||
|         with: | ||||
|           inputs: >- | ||||
|             ./dist/*.tar.gz | ||||
|             ./dist/*.whl | ||||
|       - name: Create GitHub Release | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ github.token }} | ||||
|         run: >- | ||||
|           gh release create | ||||
|           "$GITHUB_REF_NAME" | ||||
|           --repo "$GITHUB_REPOSITORY" | ||||
|           --notes "" | ||||
|       - name: Upload artifact signatures to GitHub Release | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ github.token }} | ||||
|         # Upload to GitHub Release using the `gh` CLI. | ||||
|         # `dist/` contains the built packages, and the | ||||
|         # sigstore-produced signatures and certificates. | ||||
|         run: >- | ||||
|           gh release upload | ||||
|           "$GITHUB_REF_NAME" dist/** | ||||
|           --repo "$GITHUB_REPOSITORY" | ||||
|  | ||||
|   publish-to-testpypi: | ||||
|     name: Publish Python 🐍 distribution 📦 to TestPyPI | ||||
|     needs: | ||||
|       - build | ||||
|     runs-on: ubuntu-latest | ||||
|  | ||||
|     environment: | ||||
|       name: testpypi | ||||
|       url: https://test.pypi.org/p/kernels | ||||
|  | ||||
|     permissions: | ||||
|       id-token: write # IMPORTANT: mandatory for trusted publishing | ||||
|  | ||||
|     steps: | ||||
|       - name: Download all the dists | ||||
|         uses: actions/download-artifact@v4 | ||||
|         with: | ||||
|           name: python-package-distributions | ||||
|           path: dist/ | ||||
|       - name: Publish distribution 📦 to TestPyPI | ||||
|         uses: pypa/gh-action-pypi-publish@release/v1 | ||||
|         with: | ||||
|           repository-url: https://test.pypi.org/legacy/ | ||||
|           skip-existing: true # Only upload when the version is unique. | ||||
							
								
								
									
										34
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										34
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -24,7 +24,7 @@ jobs: | ||||
|       max-parallel: 4 | ||||
|       matrix: | ||||
|         python-version: ["3.10", "3.12"] | ||||
|         torch-version: ["2.5.1", "2.6.0"] | ||||
|         torch-version: ["2.6.0", "2.7.0"] | ||||
|  | ||||
|     env: | ||||
|       UV_PYTHON_PREFERENCE: only-managed | ||||
| @ -51,4 +51,34 @@ jobs: | ||||
|         run: uv run mypy src/kernels | ||||
|  | ||||
|       - name: Run tests | ||||
|         run: uv run pytest tests | ||||
|         run: | | ||||
|           uv run pytest tests | ||||
|  | ||||
|       - name: Run staging tests | ||||
|         env: | ||||
|           HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }} | ||||
|         run: | | ||||
|           HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/ | ||||
|         if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0' | ||||
|  | ||||
|       - name: Check kernel conversion | ||||
|         run: | | ||||
|           uv pip install wheel | ||||
|           uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1 | ||||
|           uv pip install triton_layer_norm-0.0.1*.whl | ||||
|           uv run python -c "import triton_layer_norm" | ||||
|  | ||||
|       - name: Check README generation | ||||
|         # For now, just checks that generation doesn't fail. | ||||
|         run: | | ||||
|           uv run kernels generate-readme kernels-community/triton-layer-norm | ||||
|  | ||||
|       - name: Check kernel check | ||||
|         run: | | ||||
|           uv pip install kernel-abi-check | ||||
|           kernels check kernels-community/activation | ||||
|  | ||||
|       - name: Import check without torch | ||||
|         run: | | ||||
|           uv pip uninstall torch | ||||
|           python -c "import kernels" | ||||
|  | ||||
							
								
								
									
										16
									
								
								.github/workflows/upload_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/upload_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,16 @@ | ||||
| name: Upload PR Documentation | ||||
|  | ||||
| on: | ||||
|   workflow_run: | ||||
|     workflows: ["Build PR Documentation"] | ||||
|     types: | ||||
|       - completed | ||||
|  | ||||
| jobs: | ||||
|   build: | ||||
|     uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main | ||||
|     with: | ||||
|       package_name: kernels | ||||
|     secrets: | ||||
|       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} | ||||
|       comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} | ||||
							
								
								
									
										201
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,201 @@ | ||||
|                                  Apache License | ||||
|                            Version 2.0, January 2004 | ||||
|                         http://www.apache.org/licenses/ | ||||
|  | ||||
|    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||
|  | ||||
|    1. Definitions. | ||||
|  | ||||
|       "License" shall mean the terms and conditions for use, reproduction, | ||||
|       and distribution as defined by Sections 1 through 9 of this document. | ||||
|  | ||||
|       "Licensor" shall mean the copyright owner or entity authorized by | ||||
|       the copyright owner that is granting the License. | ||||
|  | ||||
|       "Legal Entity" shall mean the union of the acting entity and all | ||||
|       other entities that control, are controlled by, or are under common | ||||
|       control with that entity. For the purposes of this definition, | ||||
|       "control" means (i) the power, direct or indirect, to cause the | ||||
|       direction or management of such entity, whether by contract or | ||||
|       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||
|       outstanding shares, or (iii) beneficial ownership of such entity. | ||||
|  | ||||
|       "You" (or "Your") shall mean an individual or Legal Entity | ||||
|       exercising permissions granted by this License. | ||||
|  | ||||
|       "Source" form shall mean the preferred form for making modifications, | ||||
|       including but not limited to software source code, documentation | ||||
|       source, and configuration files. | ||||
|  | ||||
|       "Object" form shall mean any form resulting from mechanical | ||||
|       transformation or translation of a Source form, including but | ||||
|       not limited to compiled object code, generated documentation, | ||||
|       and conversions to other media types. | ||||
|  | ||||
|       "Work" shall mean the work of authorship, whether in Source or | ||||
|       Object form, made available under the License, as indicated by a | ||||
|       copyright notice that is included in or attached to the work | ||||
|       (an example is provided in the Appendix below). | ||||
|  | ||||
|       "Derivative Works" shall mean any work, whether in Source or Object | ||||
|       form, that is based on (or derived from) the Work and for which the | ||||
|       editorial revisions, annotations, elaborations, or other modifications | ||||
|       represent, as a whole, an original work of authorship. For the purposes | ||||
|       of this License, Derivative Works shall not include works that remain | ||||
|       separable from, or merely link (or bind by name) to the interfaces of, | ||||
|       the Work and Derivative Works thereof. | ||||
|  | ||||
|       "Contribution" shall mean any work of authorship, including | ||||
|       the original version of the Work and any modifications or additions | ||||
|       to that Work or Derivative Works thereof, that is intentionally | ||||
|       submitted to Licensor for inclusion in the Work by the copyright owner | ||||
|       or by an individual or Legal Entity authorized to submit on behalf of | ||||
|       the copyright owner. For the purposes of this definition, "submitted" | ||||
|       means any form of electronic, verbal, or written communication sent | ||||
|       to the Licensor or its representatives, including but not limited to | ||||
|       communication on electronic mailing lists, source code control systems, | ||||
|       and issue tracking systems that are managed by, or on behalf of, the | ||||
|       Licensor for the purpose of discussing and improving the Work, but | ||||
|       excluding communication that is conspicuously marked or otherwise | ||||
|       designated in writing by the copyright owner as "Not a Contribution." | ||||
|  | ||||
|       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||
|       on behalf of whom a Contribution has been received by Licensor and | ||||
|       subsequently incorporated within the Work. | ||||
|  | ||||
|    2. Grant of Copyright License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       copyright license to reproduce, prepare Derivative Works of, | ||||
|       publicly display, publicly perform, sublicense, and distribute the | ||||
|       Work and such Derivative Works in Source or Object form. | ||||
|  | ||||
|    3. Grant of Patent License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       (except as stated in this section) patent license to make, have made, | ||||
|       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||
|       where such license applies only to those patent claims licensable | ||||
|       by such Contributor that are necessarily infringed by their | ||||
|       Contribution(s) alone or by combination of their Contribution(s) | ||||
|       with the Work to which such Contribution(s) was submitted. If You | ||||
|       institute patent litigation against any entity (including a | ||||
|       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||
|       or a Contribution incorporated within the Work constitutes direct | ||||
|       or contributory patent infringement, then any patent licenses | ||||
|       granted to You under this License for that Work shall terminate | ||||
|       as of the date such litigation is filed. | ||||
|  | ||||
|    4. Redistribution. You may reproduce and distribute copies of the | ||||
|       Work or Derivative Works thereof in any medium, with or without | ||||
|       modifications, and in Source or Object form, provided that You | ||||
|       meet the following conditions: | ||||
|  | ||||
|       (a) You must give any other recipients of the Work or | ||||
|           Derivative Works a copy of this License; and | ||||
|  | ||||
|       (b) You must cause any modified files to carry prominent notices | ||||
|           stating that You changed the files; and | ||||
|  | ||||
|       (c) You must retain, in the Source form of any Derivative Works | ||||
|           that You distribute, all copyright, patent, trademark, and | ||||
|           attribution notices from the Source form of the Work, | ||||
|           excluding those notices that do not pertain to any part of | ||||
|           the Derivative Works; and | ||||
|  | ||||
|       (d) If the Work includes a "NOTICE" text file as part of its | ||||
|           distribution, then any Derivative Works that You distribute must | ||||
|           include a readable copy of the attribution notices contained | ||||
|           within such NOTICE file, excluding those notices that do not | ||||
|           pertain to any part of the Derivative Works, in at least one | ||||
|           of the following places: within a NOTICE text file distributed | ||||
|           as part of the Derivative Works; within the Source form or | ||||
|           documentation, if provided along with the Derivative Works; or, | ||||
|           within a display generated by the Derivative Works, if and | ||||
|           wherever such third-party notices normally appear. The contents | ||||
|           of the NOTICE file are for informational purposes only and | ||||
|           do not modify the License. You may add Your own attribution | ||||
|           notices within Derivative Works that You distribute, alongside | ||||
|           or as an addendum to the NOTICE text from the Work, provided | ||||
|           that such additional attribution notices cannot be construed | ||||
|           as modifying the License. | ||||
|  | ||||
|       You may add Your own copyright statement to Your modifications and | ||||
|       may provide additional or different license terms and conditions | ||||
|       for use, reproduction, or distribution of Your modifications, or | ||||
|       for any such Derivative Works as a whole, provided Your use, | ||||
|       reproduction, and distribution of the Work otherwise complies with | ||||
|       the conditions stated in this License. | ||||
|  | ||||
|    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||
|       any Contribution intentionally submitted for inclusion in the Work | ||||
|       by You to the Licensor shall be under the terms and conditions of | ||||
|       this License, without any additional terms or conditions. | ||||
|       Notwithstanding the above, nothing herein shall supersede or modify | ||||
|       the terms of any separate license agreement you may have executed | ||||
|       with Licensor regarding such Contributions. | ||||
|  | ||||
|    6. Trademarks. This License does not grant permission to use the trade | ||||
|       names, trademarks, service marks, or product names of the Licensor, | ||||
|       except as required for reasonable and customary use in describing the | ||||
|       origin of the Work and reproducing the content of the NOTICE file. | ||||
|  | ||||
|    7. Disclaimer of Warranty. Unless required by applicable law or | ||||
|       agreed to in writing, Licensor provides the Work (and each | ||||
|       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||
|       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
|       implied, including, without limitation, any warranties or conditions | ||||
|       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||
|       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||
|       appropriateness of using or redistributing the Work and assume any | ||||
|       risks associated with Your exercise of permissions under this License. | ||||
|  | ||||
|    8. Limitation of Liability. In no event and under no legal theory, | ||||
|       whether in tort (including negligence), contract, or otherwise, | ||||
|       unless required by applicable law (such as deliberate and grossly | ||||
|       negligent acts) or agreed to in writing, shall any Contributor be | ||||
|       liable to You for damages, including any direct, indirect, special, | ||||
|       incidental, or consequential damages of any character arising as a | ||||
|       result of this License or out of the use or inability to use the | ||||
|       Work (including but not limited to damages for loss of goodwill, | ||||
|       work stoppage, computer failure or malfunction, or any and all | ||||
|       other commercial damages or losses), even if such Contributor | ||||
|       has been advised of the possibility of such damages. | ||||
|  | ||||
|    9. Accepting Warranty or Additional Liability. While redistributing | ||||
|       the Work or Derivative Works thereof, You may choose to offer, | ||||
|       and charge a fee for, acceptance of support, warranty, indemnity, | ||||
|       or other liability obligations and/or rights consistent with this | ||||
|       License. However, in accepting such obligations, You may act only | ||||
|       on Your own behalf and on Your sole responsibility, not on behalf | ||||
|       of any other Contributor, and only if You agree to indemnify, | ||||
|       defend, and hold each Contributor harmless for any liability | ||||
|       incurred by, or claims asserted against, such Contributor by reason | ||||
|       of your accepting any such warranty or additional liability. | ||||
|  | ||||
|    END OF TERMS AND CONDITIONS | ||||
|  | ||||
|    APPENDIX: How to apply the Apache License to your work. | ||||
|  | ||||
|       To apply the Apache License to your work, attach the following | ||||
|       boilerplate notice, with the fields enclosed by brackets "[]" | ||||
|       replaced with your own identifying information. (Don't include | ||||
|       the brackets!)  The text should be enclosed in the appropriate | ||||
|       comment syntax for the file format. We also recommend that a | ||||
|       file or class name and description of purpose be included on the | ||||
|       same "printed page" as the copyright notice for easier | ||||
|       identification within third-party archives. | ||||
|  | ||||
|    Copyright [yyyy] [name of copyright owner] | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
							
								
								
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,8 @@ | ||||
| .PHONY: style | ||||
|  | ||||
| export check_dirs := src examples tests | ||||
|  | ||||
| style: | ||||
| 	black ${check_dirs} | ||||
| 	isort ${check_dirs} | ||||
| 	ruff check ${check_dirs} --fix | ||||
							
								
								
									
										23
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								README.md
									
									
									
									
									
								
							| @ -1,5 +1,16 @@ | ||||
| # kernels | ||||
|  | ||||
| <div align="center"> | ||||
| <img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo"> | ||||
| <p align="center"> | ||||
|     <a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a> | ||||
|     <a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a> | ||||
|     <a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a> | ||||
|    | ||||
| </p> | ||||
| </div> | ||||
| <hr/> | ||||
|  | ||||
| The Kernel Hub allows Python libraries and applications to load compute | ||||
| kernels directly from the [Hub](https://hf.co/). To support this kind | ||||
| of dynamic loading, Hub kernels differ from traditional Python kernel | ||||
| @ -45,8 +56,12 @@ the Hub. | ||||
|  | ||||
| ## 📚 Documentation | ||||
|  | ||||
| - [Using layers](docs/layers.md) | ||||
| - [Locking kernel versions](docs/locking.md) | ||||
| - [Using kernels in a Docker container](docs/docker.md) | ||||
| - [Kernel requirements](docs/kernel-requirements.md) | ||||
| - [Introduction](docs/source/index.md) | ||||
| - [Installation](docs/source/installation.md) | ||||
| - [Basic usage](docs/source/basic-usage.md) | ||||
| - [Using layers](docs/source/layers.md) | ||||
| - [Locking kernel/layer versions](docs/source/locking.md) | ||||
| - [Environment variables](docs/source/env.md) | ||||
| - [Kernel requirements](docs/source/kernel-requirements.md) | ||||
| - [Frequently Asked Questions](docs/source/faq.md) | ||||
| - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||
|  | ||||
| @ -1,8 +0,0 @@ | ||||
| # Using kernels in a Docker container | ||||
|  | ||||
| build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: | ||||
|  | ||||
| ```bash | ||||
| docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||
| docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||
| ``` | ||||
| @ -1,79 +0,0 @@ | ||||
| # Layers | ||||
|  | ||||
| A kernel can provide layers in addition to kernel functions. A layer from | ||||
| the Hub can replace the `forward` method of an existing layer for a certain | ||||
| device type. This makes it possible to provide more performant kernels for | ||||
| existing layers. | ||||
|  | ||||
| See [Kernel requirements](kernel-requirements.md) for more information the | ||||
| requirements of Hub layers. | ||||
|  | ||||
| ## Making a layer extensible with kernels from the hub | ||||
|  | ||||
| ### Using a decorator | ||||
|  | ||||
| A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||
| decorator. For example: | ||||
|  | ||||
| ```python | ||||
| @use_kernel_forward_from_hub("SiluAndMul") | ||||
| class SiluAndMul(nn.Module): | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         d = input.shape[-1] // 2 | ||||
|         return F.silu(input[..., :d]) * input[..., d:] | ||||
| ``` | ||||
|  | ||||
| The decorator changes the layer, so that other implementations of the `forward` | ||||
| method can be registered using the name `SiluAndMul`. | ||||
|  | ||||
| ### External layers | ||||
|  | ||||
| An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||
| decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function. | ||||
|  | ||||
| ```python | ||||
| from somelibrary import SiluAndMul | ||||
|  | ||||
| replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
|  | ||||
| The `register_kernel_mapping` call maps the name `SiluAndMul` to actual | ||||
| hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer) | ||||
| section for more information. | ||||
|  | ||||
| **Warning:** we strongly recommend using layers with a decorator, since | ||||
| it signifies that the maintainer intends to keep the `forward` signature | ||||
| compatible with layers from the hub. | ||||
|  | ||||
| ## Registering a hub kernel for a layer | ||||
|  | ||||
| Once a layer is made extensible, users can register hub kernels for it | ||||
| by name using the `register_kernel_mapping` function. For example: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             revision="layers", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
|  | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
|  | ||||
| This will register the kernel mapping in the current context, which is | ||||
| normally global. It is recommended to scope the mapping to where it is | ||||
| used with the `use_kernel_mapping` context manager: | ||||
|  | ||||
| ```python | ||||
| with use_kernel_mapping(kernel_layer_mapping): | ||||
|     # Use the layer for which the mapping is applied. | ||||
|     ... | ||||
| ``` | ||||
|  | ||||
| This ensures that the mapping is not active anymore outside the | ||||
| `with`-scope. | ||||
							
								
								
									
										30
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,30 @@ | ||||
| - sections: | ||||
|     - local: index | ||||
|       title: Introduction | ||||
|     - local: installation | ||||
|       title: Installation | ||||
|   title: Getting started | ||||
| - sections: | ||||
|     - local: basic-usage | ||||
|       title: Basic Usage | ||||
|     - local: layers | ||||
|       title: Using Layers | ||||
|     - local: locking | ||||
|       title: Locking Kernel Versions | ||||
|     - local: env | ||||
|       title: Environment Variables | ||||
|     - local: faq | ||||
|       title: FAQ | ||||
|   title: Usage Guide | ||||
| - sections: | ||||
|     - local: api/kernels | ||||
|       title: Kernels | ||||
|     - local: api/layers | ||||
|       title: Layers | ||||
|     - local: cli | ||||
|       title: Kernels CLI | ||||
|   title: API Reference | ||||
| - sections: | ||||
|     - local: kernel-requirements | ||||
|       title: Kernel Requirements | ||||
|   title: Developer Guide | ||||
							
								
								
									
										25
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,25 @@ | ||||
| # Kernels API Reference | ||||
|  | ||||
| ## Main Functions | ||||
|  | ||||
| ### get_kernel | ||||
|  | ||||
| [[autodoc]] kernels.get_kernel | ||||
|  | ||||
| ### get_local_kernel | ||||
|  | ||||
| [[autodoc]] kernels.get_local_kernel | ||||
|  | ||||
| ### has_kernel | ||||
|  | ||||
| [[autodoc]] kernels.has_kernel | ||||
|  | ||||
| ## Loading locked kernels | ||||
|  | ||||
| ### load_kernel | ||||
|  | ||||
| [[autodoc]] kernels.load_kernel | ||||
|  | ||||
| ### get_locked_kernel | ||||
|  | ||||
| [[autodoc]] kernels.get_locked_kernel | ||||
							
								
								
									
										49
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | ||||
| # Layers API Reference | ||||
|  | ||||
| ## Making layers kernel-aware | ||||
|  | ||||
| ### use_kernel_forward_from_hub | ||||
|  | ||||
| [[autodoc]] kernels.use_kernel_forward_from_hub | ||||
|  | ||||
| ### replace_kernel_forward_from_hub | ||||
|  | ||||
| [[autodoc]] kernels.replace_kernel_forward_from_hub | ||||
|  | ||||
| ## Registering kernel mappings | ||||
|  | ||||
| ### use_kernel_mapping | ||||
|  | ||||
| [[autodoc]] kernels.use_kernel_mapping | ||||
|  | ||||
| ### register_kernel_mapping | ||||
|  | ||||
| [[autodoc]] kernels.register_kernel_mapping | ||||
|  | ||||
| ## Kernelizing a model | ||||
|  | ||||
| ### kernelize | ||||
|  | ||||
| [[autodoc]] kernels.kernelize | ||||
|  | ||||
| ## Classes | ||||
|  | ||||
| ### Device | ||||
|  | ||||
| [[autodoc]] kernels.Device | ||||
|  | ||||
| ### Mode | ||||
|  | ||||
| [[autodoc]] kernels.Mode | ||||
|  | ||||
| ### LayerRepository | ||||
|  | ||||
| [[autodoc]] kernels.LayerRepository | ||||
|  | ||||
| ### LocalLayerRepository | ||||
|  | ||||
| [[autodoc]] kernels.LocalLayerRepository | ||||
|  | ||||
| ### LockedLayerRepository | ||||
|  | ||||
| [[autodoc]] kernels.LockedLayerRepository | ||||
							
								
								
									
										50
									
								
								docs/source/basic-usage.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								docs/source/basic-usage.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,50 @@ | ||||
| # Basic Usage | ||||
|  | ||||
| ## Loading Kernels | ||||
|  | ||||
| Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub: | ||||
|  | ||||
| ```python | ||||
| import torch | ||||
| from kernels import get_kernel | ||||
|  | ||||
| # Download optimized kernels from the Hugging Face hub | ||||
| activation = get_kernel("kernels-community/activation") | ||||
|  | ||||
| # Create a random tensor | ||||
| x = torch.randn((10, 10), dtype=torch.float16, device="cuda") | ||||
|  | ||||
| # Run the kernel | ||||
| y = torch.empty_like(x) | ||||
| activation.gelu_fast(y, x) | ||||
|  | ||||
| print(y) | ||||
| ``` | ||||
|  | ||||
| ### Using version bounds | ||||
|  | ||||
| Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`. | ||||
| You can specify which version to download using Python version specifiers: | ||||
|  | ||||
| ```python | ||||
| import torch | ||||
| from kernels import get_kernel | ||||
|  | ||||
| activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0") | ||||
| ``` | ||||
|  | ||||
| This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It | ||||
| is strongly recommended to specify a version bound, since a kernel author | ||||
| might push incompatible changes to the `main` branch. | ||||
|  | ||||
| ## Checking Kernel Availability | ||||
|  | ||||
| You can check if a specific kernel is available for your environment: | ||||
|  | ||||
| ```python | ||||
| from kernels import has_kernel | ||||
|  | ||||
| # Check if kernel is available for current environment | ||||
| is_available = has_kernel("kernels-community/activation") | ||||
| print(f"Kernel available: {is_available}") | ||||
| ``` | ||||
							
								
								
									
										58
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| # Kernels CLI Reference | ||||
|  | ||||
| ## Main Functions | ||||
|  | ||||
| ### kernels check | ||||
|  | ||||
| You can use `kernels check` to test compliance of a kernel on the Hub. | ||||
| This currently checks that the kernel: | ||||
|  | ||||
| - Supports the currently-required Python ABI version. | ||||
| - Works on supported operating system versions. | ||||
|  | ||||
| For example: | ||||
|  | ||||
| ```bash | ||||
| $ kernels check kernels-community/flash-attn3 | ||||
| Checking variant: torch28-cxx11-cu128-aarch64-linux | ||||
|   🐍 Python ABI 3.9 compatible | ||||
|   🐧 manylinux_2_28 compatible | ||||
| [...] | ||||
| ``` | ||||
|  | ||||
| ### kernels to-wheel | ||||
|  | ||||
| We strongly recommend downloading kernels from the Hub using the `kernels` | ||||
| package, since this comes with large [benefits](index.md) over using Python | ||||
| wheels. That said, some projects may require deployment of kernels as | ||||
| wheels. The `kernels` utility provides a simple solution to this. You can | ||||
| convert any Hub kernel into a set of wheels with the `to-wheel` command: | ||||
|  | ||||
| ```bash | ||||
| $ kernels to-wheel drbh/img2grey 1.1.2 | ||||
| ☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ``` | ||||
|  | ||||
| ### kernels upload | ||||
|  | ||||
| Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload | ||||
| your kernel builds to the Hub. To know the supported arguments run: `kernels upload -h`. | ||||
|  | ||||
| **Notes**: | ||||
|  | ||||
| - This will take care of creating a repository on the Hub with the `repo_id` provided. | ||||
| - If a repo with the `repo_id` already exists and if it contains a `build` with the build variant | ||||
|   being uploaded, it will attempt to delete the files existing under it. | ||||
| - Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub. | ||||
							
								
								
									
										10
									
								
								docs/source/env.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								docs/source/env.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,10 @@ | ||||
| # Environment variables | ||||
|  | ||||
| ## `KERNELS_CACHE` | ||||
|  | ||||
| The directory to use as the local kernel cache. If not set, the cache | ||||
| of the `huggingface_hub` package is used. | ||||
|  | ||||
| ## `DISABLE_KERNEL_MAPPING` | ||||
|  | ||||
| Disables kernel mappings for [`layers`](layers.md). | ||||
							
								
								
									
										41
									
								
								docs/source/faq.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								docs/source/faq.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,41 @@ | ||||
| # FAQ | ||||
|  | ||||
| ## Kernel layers | ||||
|  | ||||
| ### Why is the kernelization step needed as a separate step? | ||||
|  | ||||
| In earlier versions of `kernels`, a layer's `forward` method was replaced | ||||
| by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. | ||||
| The new `forward` would dispatch to a kernel based on the device type, | ||||
| whether a model was training, etc. However, this approach was | ||||
| fundamentally incompatible with `torch.compile` since it relied | ||||
| on data-dependent branching. | ||||
|  | ||||
| To avoid branching, we have to make dispatch decisions ahead of time, | ||||
| which is what the `kernelize` function does. | ||||
|  | ||||
| ### Why does kernelization only replace `forward` methods? | ||||
|  | ||||
| There are some other possible approaches. The first is to completely | ||||
| replace existing layers by kernel layers. However, since this would | ||||
| permit free-form layer classes, it would be much harder to validate | ||||
| that layers are fully compatible with the layers that they are | ||||
| replacing. For instance, they could have completely different member | ||||
| variables. Besides that, we would also need to hold on to the original | ||||
| layers, in case we need to revert to the base layers when the model | ||||
| is `kernelize`d again with different options. | ||||
|  | ||||
| A second approach would be to make an auxiliary layer that wraps the | ||||
| original layer and the kernel layer and dispatches to the kernel layer. | ||||
| This wouldn't have the issues of the first approach, because kernel layers | ||||
| could be similarly strict as they are now, and we would still have access | ||||
| to the original layers when `kernelize`-ing the model again. However, | ||||
| this would change the graph structure of the model and would break use | ||||
| cases where programs access the model internals (e.g. | ||||
| `model.layers[0].attention.query_weight`) or rely on the graph structure | ||||
| in other ways. | ||||
|  | ||||
| The approach of `forward`-replacement is the least invasive, because | ||||
| it preserves the original model graph. It is also reversible, since | ||||
| even though the `forward` of a layer _instance_ might be replaced, | ||||
| the corresponding class still has the original `forward`. | ||||
							
								
								
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | ||||
| # Kernels | ||||
|  | ||||
| <div align="center"> | ||||
| <img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo"> | ||||
| </div> | ||||
|  | ||||
| The Kernel Hub allows Python libraries and applications to load compute | ||||
| kernels directly from the [Hub](https://hf.co/). To support this kind | ||||
| of dynamic loading, Hub kernels differ from traditional Python kernel | ||||
| packages in that they are made to be: | ||||
|  | ||||
| - **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`. | ||||
| - **Unique**: multiple versions of the same kernel can be loaded in the | ||||
|   same Python process. | ||||
| - **Compatible**: kernels must support all recent versions of Python and | ||||
|   the different PyTorch build configurations (various CUDA versions | ||||
|   and C++ ABIs). Furthermore, older C library versions must be supported. | ||||
|  | ||||
| You can [search for kernels](https://huggingface.co/models?other=kernel) on | ||||
| the Hub. | ||||
							
								
								
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,16 @@ | ||||
| # Installation | ||||
|  | ||||
| Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA): | ||||
|  | ||||
| ```bash | ||||
| pip install kernels | ||||
| ``` | ||||
|  | ||||
| # Using kernels in a Docker container | ||||
|  | ||||
| Build and run the reference `examples/basic.py` in a Docker container with the following commands: | ||||
|  | ||||
| ```bash | ||||
| docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||
| docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||
| ``` | ||||
| @ -1,8 +1,11 @@ | ||||
| # Kernel requirements | ||||
| 
 | ||||
| Kernels on the Hub must fulfill the requirements outlined on this page. | ||||
| Kernels on the Hub must fulfill the requirements outlined on this page. By | ||||
| ensuring kernels are compliant, they can be used on a wide range of Linux | ||||
| systems and Torch builds. | ||||
| 
 | ||||
| You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||
| to build conforming kernels. | ||||
| to build compliant kernels. | ||||
| 
 | ||||
| ## Directory layout | ||||
| 
 | ||||
| @ -10,52 +13,83 @@ A kernel repository on the Hub must contain a `build` directory. This | ||||
| directory contains build variants of a kernel in the form of directories | ||||
| following the template | ||||
| `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`. | ||||
| For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently | ||||
| recommended build variants are: | ||||
| For example `build/torch26-cxx98-cu118-x86_64-linux`. | ||||
| 
 | ||||
| - `torch25-cxx11-cu118-x86_64-linux` | ||||
| - `torch25-cxx11-cu121-x86_64-linux` | ||||
| - `torch25-cxx11-cu124-x86_64-linux` | ||||
| - `torch25-cxx98-cu118-x86_64-linux` | ||||
| - `torch25-cxx98-cu121-x86_64-linux` | ||||
| - `torch25-cxx98-cu124-x86_64-linux` | ||||
| - `torch26-cxx11-cu118-x86_64-linux` | ||||
| - `torch26-cxx11-cu124-x86_64-linux` | ||||
| - `torch26-cxx11-cu126-x86_64-linux` | ||||
| - `torch26-cxx98-cu118-x86_64-linux` | ||||
| - `torch26-cxx98-cu124-x86_64-linux` | ||||
| - `torch26-cxx98-cu126-x86_64-linux` | ||||
| 
 | ||||
| This list will be updated as new PyTorch versions are released. Kernels | ||||
| that are in pure Python (e.g. Triton kernels) only need to provide a | ||||
| single build variant: | ||||
| 
 | ||||
| - `torch-universal` | ||||
| 
 | ||||
| Each variant directory should contain a single directory with the same name | ||||
| Each variant directory must contain a single directory with the same name | ||||
| as the repository (replacing `-` by `_`). For instance, kernels in the | ||||
| `kernels-community/activation` repository have a directories like | ||||
| `build/<variant>/activation`. This directory | ||||
| must be a Python package with an `__init__.py` file. | ||||
| 
 | ||||
| ## Build variants | ||||
| 
 | ||||
| A kernel can be compliant for a specific compute framework (e.g. CUDA) or | ||||
| architecture (e.g. x86_64). For compliance with a compute framework and | ||||
| architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md) | ||||
| must be available for that combination. | ||||
| 
 | ||||
| ## Versioning | ||||
| 
 | ||||
| Kernels are versioned on the Hub using Git tags. Version tags must be of | ||||
| the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md) | ||||
| to resolve the version constraints. | ||||
| 
 | ||||
| We recommend using [semver](https://semver.org/) to version kernels. | ||||
| 
 | ||||
| ## Native Python module | ||||
| 
 | ||||
| Kernels will typically contain a native Python module with precompiled | ||||
| compute kernels and bindings. This module must fulfill the following | ||||
| requirements: | ||||
| compute kernels and bindings. This module must fulfill the requirements | ||||
| outlined in this section. For all operating systems, a kernel must not | ||||
| have dynamic library dependencies outside: | ||||
| 
 | ||||
| - Torch; | ||||
| - CUDA/ROCm libraries installed as dependencies of Torch. | ||||
| 
 | ||||
| ## Compatibility with torch.compile | ||||
| 
 | ||||
| The Kernel Hub also encourages to write the kernels in a `torch.compile` | ||||
| compliant way. This helps to ensure that the kernels are compatible with | ||||
| `torch.compile` without introducing any graph breaks and triggering  | ||||
| recompilation which can limit the benefits of compilation. | ||||
| 
 | ||||
| [Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and  | ||||
| recompilation triggers during `torch.compile`. | ||||
| 
 | ||||
| ### Linux | ||||
| 
 | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| - Compatible with glibc 2.27 or later. This means that no symbols | ||||
|   from later versions must be used. To archive this, the module should | ||||
|   be built against this glibc version. **Warning:** libgcc must also be | ||||
|   built against glibc 2.27 to avoid leaking symbols. | ||||
| - No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols | ||||
|   must be static. | ||||
| - No dynamic library dependencies outside Torch or CUDA libraries | ||||
|   installed as dependencies of Torch. | ||||
| - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based). | ||||
|   This means that the extension **must not** use symbols versions higher than: | ||||
|   - GLIBC 2.28 | ||||
|   - GLIBCXX 3.4.24 | ||||
|   - CXXABI 1.3.11 | ||||
|   - GCC 7.0.0 | ||||
| 
 | ||||
| (These requirements will be updated as new PyTorch versions are released.) | ||||
| These requirements can be checked with the ABI checker (see below). | ||||
| 
 | ||||
| ### macOS | ||||
| 
 | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| - macOS deployment target 15.0. | ||||
| - Metal 3.0 (`-std=metal3.0`). | ||||
| 
 | ||||
| The ABI3 requirement can be checked with the ABI checker (see below). | ||||
| 
 | ||||
| ### ABI checker | ||||
| 
 | ||||
| The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | ||||
| [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | ||||
| 
 | ||||
| ```bash | ||||
| 
 | ||||
| $ cargo install kernel-abi-check | ||||
| $ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so | ||||
| 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 | ||||
| ✅ No compatibility issues found | ||||
| ``` | ||||
| 
 | ||||
| ## Torch extension | ||||
| 
 | ||||
| @ -98,10 +132,20 @@ requirements: | ||||
| - The `forward` method has a signature that is compatible with the | ||||
|   `forward` method that it is extending. | ||||
| 
 | ||||
| There are two exceptions to the _no class variables rule_: | ||||
| 
 | ||||
| 1. The `has_backward` variable can be used to indicate whether the layer has | ||||
|    a backward pass implemented (`True` when absent). | ||||
| 2. The `can_torch_compile` variable can be used to indicate whether the layer | ||||
|    supports `torch.compile` (`False` when absent). | ||||
| 
 | ||||
| This is an example of a pure layer: | ||||
| 
 | ||||
| ```python | ||||
| class SiluAndMul(nn.Module): | ||||
|     # This layer does not implement backward. | ||||
|     has_backward: bool = False | ||||
| 
 | ||||
|     def forward(self, x: torch.Tensor): | ||||
|         d = x.shape[-1] // 2 | ||||
|         output_shape = x.shape[:-1] + (d,) | ||||
							
								
								
									
										323
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										323
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,323 @@ | ||||
| # Layers | ||||
|  | ||||
| A kernel can provide layers in addition to kernel functions. A layer from | ||||
| the Hub can replace the `forward` method of an existing layer for a certain | ||||
| device type. This makes it possible to provide more performant kernels for | ||||
| existing layers. | ||||
|  | ||||
| See [Kernel requirements](kernel-requirements.md) for more information on the | ||||
| requirements of Hub layers. | ||||
|  | ||||
| ## Making a layer extensible with kernels from the hub | ||||
|  | ||||
| ### Using a decorator | ||||
|  | ||||
| A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||
| decorator. For example: | ||||
|  | ||||
| ```python | ||||
| @use_kernel_forward_from_hub("SiluAndMul") | ||||
| class SiluAndMul(nn.Module): | ||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||
|         d = input.shape[-1] // 2 | ||||
|         return F.silu(input[..., :d]) * input[..., d:] | ||||
| ``` | ||||
|  | ||||
| The decorator does not change the behavior of the class -- it annotates | ||||
| the class with the given name (here `SiluAndMul`). The `kernelize` function | ||||
| described below uses this name to look up kernels for the layer. | ||||
|  | ||||
| ### External layers | ||||
|  | ||||
| An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||
| decorator can be made extensible using the `replace_kernel_forward_from_hub` | ||||
| function: | ||||
|  | ||||
| ```python | ||||
| from somelibrary import SiluAndMul | ||||
|  | ||||
| replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||
| ``` | ||||
|  | ||||
| **Warning:** we strongly recommend using layers with a decorator, since | ||||
| it signifies that the maintainer intends to keep the `forward` signature | ||||
| compatible with layers from the hub. | ||||
|  | ||||
| ## Kernelizing a model | ||||
|  | ||||
| A model will not use Hub kernels by default, even if it contains extensible | ||||
| layers. To enable the use of Hub kernels in the model, it needs to be | ||||
| 'kernelized' using the `kernelize` function. This function traverses the | ||||
| model graph and replaces the `forward` methods of extensible layers for which | ||||
| Hub kernels are registered. `kernelize` can be used as follows: | ||||
|  | ||||
| ```python | ||||
| model = MyModel(...) | ||||
| model = kernelize(model, mode=Mode.INFERENCE) | ||||
| ``` | ||||
|  | ||||
| The `kernelize` function modifies the model in-place, the model itself is | ||||
| returned as a convenience. The `mode` specifies that the model will be used | ||||
| in inference. Similarly, you can ask `kernelize` to prepare the model for | ||||
| training: | ||||
|  | ||||
| ```python | ||||
| model = MyModel(...) | ||||
| model = kernelize(model, mode=Mode.TRAINING) | ||||
| ``` | ||||
|  | ||||
| A model that is kernelized for training can also be used for inference, but | ||||
| not the other way around. If you want to change the mode of the kernelized | ||||
| model, you can just run `kernelize` on the model again with the new mode. | ||||
|  | ||||
| If you want to compile a model with `torch.compile`, this should be indicated | ||||
| in the mode as well. You can do this by combining `Mode.INFERENCE` or | ||||
| `Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator: | ||||
|  | ||||
| ```python | ||||
| model = MyModel(...) | ||||
|  | ||||
| # Inference | ||||
| model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE) | ||||
|  | ||||
| # Training | ||||
| model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
| ``` | ||||
|  | ||||
| ### Kernel device | ||||
|  | ||||
| Kernels can be registered per device type. For instance, separate `cuda` and | ||||
| `metal` kernels could be registered for the name `SiluAndMul`. By default, | ||||
| `kernelize` will try to infer the device type from the model's parameters. | ||||
| You can pass the device type to `kernelize` if the device type cannot be | ||||
| inferred (e.g. because the model has no parameters): | ||||
|  | ||||
| ```python | ||||
| model = MyModel(...) | ||||
| model = kernelize(model, device="cuda", mode=Mode.INFERENCE) | ||||
| ``` | ||||
|  | ||||
| ### Fallback `forward` | ||||
|  | ||||
| If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered | ||||
| kernel does not support backward passes or `torch.compile` respectively, | ||||
| `kernelize` will fall back to the original, non-kernelized, layer. You | ||||
| can let `kernelize` raise an exception instead by using `use_fallback=False`: | ||||
|  | ||||
| ```python | ||||
| model = MyModel(...) | ||||
| model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False) | ||||
| ``` | ||||
|  | ||||
| This can be useful if you want to guarantee that Hub kernels are used. | ||||
|  | ||||
| ### Inspecting which kernels are used | ||||
|  | ||||
| The kernels that are used are logged at the `INFO` level by `kernelize`. | ||||
| See the [Python logging](https://docs.python.org/3/library/logging.html) | ||||
| documentation for information on how to configure logging. | ||||
|  | ||||
| ## Registering a hub kernel for a layer | ||||
|  | ||||
| `kernelize` relies on kernel mappings to find Hub kernels for layers. | ||||
| Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on | ||||
| the Hub. For example: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|         ), | ||||
|         "rocm": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| You can register such a mapping using `register_kernel_mapping`: | ||||
|  | ||||
| ```python | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
|  | ||||
| This will register the kernel mapping in the current context, which is | ||||
| normally global. It is recommended to scope the mapping to where it is | ||||
| used with the `use_kernel_mapping` context manager: | ||||
|  | ||||
| ```python | ||||
| with use_kernel_mapping(kernel_layer_mapping): | ||||
|     # Use the layer for which the mapping is applied. | ||||
|     model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
| ``` | ||||
|  | ||||
| This ensures that the mapping is not active anymore outside the | ||||
| `with`-scope. | ||||
|  | ||||
| ### Using version bounds | ||||
|  | ||||
| Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`. | ||||
| You can specify which version of the kernel to download using Python version | ||||
| specifiers: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             version=">=0.0.4,<0.1.0", | ||||
|         ), | ||||
|         "rocm": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             version=">=0.0.4,<0.1.0", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| This will get the layer from latest kernel tagged `v0.0.z` where `z` is at | ||||
| least 4. It is strongly recommended to specify a version bound, since a | ||||
| kernel author might push incompatible changes to the `main` branch. | ||||
|  | ||||
| ### Registering kernels for specific modes | ||||
|  | ||||
| You might want to register two different kernels for a particular layer, | ||||
| where one kernel is optimized for a specific mode. You can do so by | ||||
| registering layer repositories for specific modes. For example: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": { | ||||
|           Mode.INFERENCE: LayerRepository( | ||||
|               repo_id="kernels-community/activation-inference-optimized", | ||||
|               layer_name="SiluAndMul", | ||||
|           ), | ||||
|           Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository( | ||||
|               repo_id="kernels-community/activation-training-optimized", | ||||
|               layer_name="SiluAndMul", | ||||
|           ), | ||||
|       } | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| The `kernelize` function will attempt to use the following registered | ||||
| kernels for a given mode: | ||||
|  | ||||
| - `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` → | ||||
|   `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||
| - `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` → | ||||
|   `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||
| - `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||
| - `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||
|  | ||||
| `Mode.FALLBACK` is a special mode that is used when no other mode matches. It | ||||
| is also used when a kernel is registered without a mode, as described in the | ||||
| previous section. | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": { | ||||
|             Mode.FALLBACK: LayerRepository( | ||||
|                 repo_id="kernels-community/activation", | ||||
|                 layer_name="SiluAndMul", | ||||
|             ), | ||||
|             Mode.INFERENCE: LayerRepository( | ||||
|                 repo_id="kernels-community/activation-inference-optimized", | ||||
|                 layer_name="SiluAndMul", | ||||
|             ), | ||||
|             Mode.TRAINING: LayerRepository( | ||||
|                 repo_id="kernels-community/activation-training-optimized", | ||||
|                 layer_name="SiluAndMul", | ||||
|             ), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and | ||||
| `Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel, | ||||
| since the other kernels do not support `torch.compile`. | ||||
|  | ||||
| ### Registering kernels for specific CUDA capabilities | ||||
|  | ||||
| Some kernels only work with newer CUDA architectures. For instance, some | ||||
| kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels` | ||||
| supports registering layers for a range of CUDA capabilities. To do so, | ||||
| you need to register the layer for a `Device` with type `cuda` and | ||||
| set the supported range of CUDA capabilities with using `CUDAProperties`: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         Device( | ||||
|             type="cuda", | ||||
|             properties=CUDAProperties( | ||||
|                 min_capability=75, max_capability=89 | ||||
|             ), | ||||
|         ): LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|         ), | ||||
|         Device( | ||||
|             type="cuda", | ||||
|             properties=CUDAProperties( | ||||
|                 min_capability=90, max_capability=sys.maxsize | ||||
|             ), | ||||
|         ): LayerRepository( | ||||
|             repo_id="kernels-community/activation-hopper", | ||||
|             layer_name="SiluAndMul", | ||||
|         ), | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| Capabilities behave as follows: | ||||
|  | ||||
| - The minimum and maximum capabilities are inclusive. | ||||
| - When a new kernel is registered with the same min/max capabilities as | ||||
|   an existing kernel, the new kernel will replace the old kernel. | ||||
| - When there are multiple kernels that support a capability, the kernel | ||||
|   with the smaller capability interval will be used. E.g. given: | ||||
|   - `KernelA` with `min_capability=80` and `max_capability=89`; | ||||
|   - `KernelB` with `min_capability=75` and `max_capability=89`; | ||||
|   - `kernelize` runs on a system with capability 8.6. | ||||
|  | ||||
|   Then `KernelA` will be used because the interval 80..89 is smaller | ||||
|   than 75..89. The motivation is that kernels with smaller ranges | ||||
|   tend to be more optimized for a specific set of GPUs. **This behavior | ||||
|   might still change in the future.** | ||||
|  | ||||
| ### Registering kernels for specific ROCm capabilities | ||||
|  | ||||
| Registering kernels for the ROCm architecture follows the exact same | ||||
| pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict | ||||
| a kernel to a range of ROCm capabilities. | ||||
|  | ||||
| ### Loading from a local repository for testing | ||||
|  | ||||
| The `LocalLayerRepository` class is provided to load a repository from | ||||
| a local directory. For example: | ||||
|  | ||||
| ```python | ||||
| with use_kernel_mapping( | ||||
|     { | ||||
|         "SiluAndMul": { | ||||
|             "cuda": LocalLayerRepository( | ||||
|                 repo_path="/home/daniel/kernels/activation", | ||||
|                 package_name="activation", | ||||
|                 layer_name="SiluAndMul", | ||||
|             ) | ||||
|         } | ||||
|     }, | ||||
|     inherit_mapping=False, | ||||
| ): | ||||
|     kernelize(linear, mode=Mode.INFERENCE) | ||||
| ``` | ||||
| @ -1,4 +1,4 @@ | ||||
| # Locking kernel versions | ||||
| # Locking kernel/layer versions | ||||
| 
 | ||||
| Projects that use `setuptools` can lock the kernel versions that should be | ||||
| used. First specify the accepted versions in `pyproject.toml` and make | ||||
| @ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" | ||||
| "kernels-community/activation" = ">=0.0.1" | ||||
| ``` | ||||
| 
 | ||||
| Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with | ||||
| Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with | ||||
| the locked revisions. The locked revision will be used when loading a kernel with | ||||
| `get_locked_kernel`: | ||||
| 
 | ||||
| @ -26,9 +26,27 @@ activation = get_locked_kernel("kernels-community/activation") | ||||
| **Note:** the lock file is included in the package metadata, so it will only be visible | ||||
| to `kernels` after doing an (editable or regular) installation of your project. | ||||
| 
 | ||||
| ## Locked kernel layers | ||||
| 
 | ||||
| Locking is also supported for kernel layers. To use locked layers, register them | ||||
| with the `LockedLayerRepository` class: | ||||
| 
 | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LockedLayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
| ``` | ||||
| 
 | ||||
| ## Pre-downloading locked kernels | ||||
| 
 | ||||
| Locked kernels can be pre-downloaded by running `kernel download .` in your | ||||
| Locked kernels can be pre-downloaded by running `kernels download .` in your | ||||
| project directory. This will download the kernels to your local Hugging Face | ||||
| Hub cache. | ||||
| 
 | ||||
| @ -20,11 +20,11 @@ activation.gelu_fast(y, x) | ||||
| print("Kernel successfully executed") | ||||
|  | ||||
| # Check results | ||||
| expected = torch.tensor([ | ||||
|     [0.8408, 1.9551, 2.9961], | ||||
|     [4.0000, 5.0000, 6.0000], | ||||
|     [7.0000, 8.0000, 9.0000] | ||||
| ], device='cuda:0', dtype=torch.float16) | ||||
| expected = torch.tensor( | ||||
|     [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||
|     device="cuda:0", | ||||
|     dtype=torch.float16, | ||||
| ) | ||||
| assert torch.allclose(y, expected) | ||||
|  | ||||
| print("Calculated values are exact") | ||||
|  | ||||
							
								
								
									
										63
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										63
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							| @ -51,30 +51,50 @@ | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "nixpkgs": { | ||||
|     "hf-nix": { | ||||
|       "inputs": { | ||||
|         "flake-compat": "flake-compat", | ||||
|         "flake-utils": "flake-utils_2", | ||||
|         "nixpkgs": "nixpkgs" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1737453259, | ||||
|         "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", | ||||
|         "owner": "danieldk", | ||||
|         "repo": "nixpkgs", | ||||
|         "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", | ||||
|         "lastModified": 1754038838, | ||||
|         "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=", | ||||
|         "owner": "huggingface", | ||||
|         "repo": "hf-nix", | ||||
|         "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "danieldk", | ||||
|         "ref": "outlines-v0.1.4-tgi", | ||||
|         "owner": "huggingface", | ||||
|         "repo": "hf-nix", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "nixpkgs": { | ||||
|       "locked": { | ||||
|         "lastModified": 1752785354, | ||||
|         "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=", | ||||
|         "owner": "nixos", | ||||
|         "repo": "nixpkgs", | ||||
|         "rev": "d38025438a6ee456758dc03188ca6873a415463b", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "nixos", | ||||
|         "repo": "nixpkgs", | ||||
|         "rev": "d38025438a6ee456758dc03188ca6873a415463b", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "root": { | ||||
|       "inputs": { | ||||
|         "flake-utils": "flake-utils", | ||||
|         "hf-nix": "hf-nix", | ||||
|         "nixpkgs": [ | ||||
|           "tgi-nix", | ||||
|           "hf-nix", | ||||
|           "nixpkgs" | ||||
|         ], | ||||
|         "tgi-nix": "tgi-nix" | ||||
|         ] | ||||
|       } | ||||
|     }, | ||||
|     "systems": { | ||||
| @ -106,27 +126,6 @@ | ||||
|         "repo": "default", | ||||
|         "type": "github" | ||||
|       } | ||||
|     }, | ||||
|     "tgi-nix": { | ||||
|       "inputs": { | ||||
|         "flake-compat": "flake-compat", | ||||
|         "flake-utils": "flake-utils_2", | ||||
|         "nixpkgs": "nixpkgs" | ||||
|       }, | ||||
|       "locked": { | ||||
|         "lastModified": 1741617161, | ||||
|         "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=", | ||||
|         "owner": "huggingface", | ||||
|         "repo": "text-generation-inference-nix", | ||||
|         "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925", | ||||
|         "type": "github" | ||||
|       }, | ||||
|       "original": { | ||||
|         "owner": "huggingface", | ||||
|         "ref": "kernels-0.2.0", | ||||
|         "repo": "text-generation-inference-nix", | ||||
|         "type": "github" | ||||
|       } | ||||
|     } | ||||
|   }, | ||||
|   "root": "root", | ||||
|  | ||||
							
								
								
									
										22
									
								
								flake.nix
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								flake.nix
									
									
									
									
									
								
							| @ -1,7 +1,7 @@ | ||||
| { | ||||
|   inputs = { | ||||
|     tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0"; | ||||
|     nixpkgs.follows = "tgi-nix/nixpkgs"; | ||||
|     hf-nix.url = "github:huggingface/hf-nix"; | ||||
|     nixpkgs.follows = "hf-nix/nixpkgs"; | ||||
|     flake-utils.url = "github:numtide/flake-utils"; | ||||
|   }; | ||||
|   outputs = | ||||
| @ -9,23 +9,28 @@ | ||||
|       self, | ||||
|       nixpkgs, | ||||
|       flake-utils, | ||||
|       tgi-nix, | ||||
|       hf-nix, | ||||
|     }: | ||||
|     flake-utils.lib.eachDefaultSystem ( | ||||
|       system: | ||||
|       let | ||||
|         pkgs = import nixpkgs { | ||||
|           inherit system; | ||||
|           inherit (tgi-nix.lib) config; | ||||
|           config = hf-nix.lib.config system; | ||||
|           overlays = [ | ||||
|             tgi-nix.overlays.default | ||||
|             hf-nix.overlays.default | ||||
|           ]; | ||||
|         }; | ||||
|       in | ||||
|       { | ||||
|         formatter = pkgs.nixfmt-rfc-style; | ||||
|         formatter = pkgs.nixfmt-tree; | ||||
|         packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {}; | ||||
|         devShells = with pkgs; rec { | ||||
|           default = mkShell { | ||||
|             nativeBuildInputs = [ | ||||
|               # For hf-doc-builder. | ||||
|               nodejs | ||||
|             ]; | ||||
|             buildInputs = | ||||
|               [ | ||||
|                 black | ||||
| @ -34,10 +39,15 @@ | ||||
|                 ruff | ||||
|               ] | ||||
|               ++ (with python3.pkgs; [ | ||||
|                 docutils | ||||
|                 huggingface-hub | ||||
|                 (callPackage ./nix/kernel-abi-check.nix {}) | ||||
|                 mktestdocs | ||||
|                 pytest | ||||
|                 pytest-benchmark | ||||
|                 pyyaml | ||||
|                 torch | ||||
|                 types-pyyaml | ||||
|                 venvShellHook | ||||
|               ]); | ||||
|  | ||||
|  | ||||
							
								
								
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| { | ||||
|   buildPythonPackage, | ||||
|   fetchPypi, | ||||
|   rustPlatform, | ||||
| }: | ||||
|  | ||||
| buildPythonPackage rec { | ||||
|   pname = "kernel-abi-check"; | ||||
|   version = "0.6.2"; | ||||
|  | ||||
|   src = fetchPypi { | ||||
|     inherit version; | ||||
|     pname = "kernel_abi_check"; | ||||
|     hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys="; | ||||
|   }; | ||||
|  | ||||
|   cargoDeps = rustPlatform.fetchCargoVendor { | ||||
|     inherit pname version src sourceRoot; | ||||
|     hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A="; | ||||
|   }; | ||||
|  | ||||
|   sourceRoot = "kernel_abi_check-${version}/bindings/python"; | ||||
|  | ||||
|   pyproject = true; | ||||
|  | ||||
|   nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ]; | ||||
| } | ||||
| @ -1,6 +1,6 @@ | ||||
| [project] | ||||
| name = "kernels" | ||||
| version = "0.3.0" | ||||
| version = "0.10.2.dev0" | ||||
| description = "Download compute kernels" | ||||
| authors = [ | ||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||
| @ -8,13 +8,14 @@ authors = [ | ||||
|   { name = "David Holtz", email = "david@huggingface.co" }, | ||||
|   { name = "Nicolas Patry", email = "nicolas@huggingface.co" }, | ||||
| ] | ||||
| license = { text = "Apache-2.0" } | ||||
| readme = "README.md" | ||||
| requires-python = ">= 3.9" | ||||
| dependencies = [ | ||||
|   "huggingface-hub>=0.26.3", | ||||
|   "packaging>=24.2", | ||||
|   "tomli>=2.0.1; python_version<'3.11'", | ||||
|   "torch>=2.5", | ||||
|   "huggingface_hub>=0.26.0,<2.0", | ||||
|   "packaging>=20.0", | ||||
|   "pyyaml>=6", | ||||
|   "tomli>=2.0; python_version<'3.11'", | ||||
| ] | ||||
|  | ||||
| [build-system] | ||||
| @ -23,10 +24,20 @@ build-backend = "setuptools.build_meta" | ||||
|  | ||||
| [dependency-groups] | ||||
| dev = [ | ||||
|   "mypy == 1.14.1", | ||||
|   "pytest >=8", | ||||
|   "mktestdocs>=0.2.5", | ||||
|   "mypy>=1.15.0", | ||||
|   "pytest>=8", | ||||
|   # Whatever version is compatible with pytest. | ||||
|   "pytest-benchmark", | ||||
|   "torch>=2.5", | ||||
|   "types-pyyaml" | ||||
| ] | ||||
|  | ||||
| [project.optional-dependencies] | ||||
| abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"] | ||||
| torch = ["torch"] | ||||
| docs = [ | ||||
|   "hf-doc-builder", | ||||
| ] | ||||
|  | ||||
| [project.scripts] | ||||
| @ -35,6 +46,9 @@ kernels = "kernels.cli:main" | ||||
| [project.entry-points."egg_info.writers"] | ||||
| "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||
|  | ||||
| [tool.isort] | ||||
| profile = "black" | ||||
| line_length = 119 | ||||
|  | ||||
| [tool.ruff] | ||||
| exclude = [ | ||||
| @ -61,4 +75,4 @@ line-length = 119 | ||||
| # Ignored rules: | ||||
| # "E501" -> line length violation | ||||
| lint.ignore = ["E501"] | ||||
| lint.select = ["E", "F", "I", "W"] | ||||
| lint.select = ["E", "F", "W"] | ||||
|  | ||||
							
								
								
									
										9
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,9 @@ | ||||
| [pytest] | ||||
| markers = | ||||
|     cuda_only: marks tests that should only hosts with CUDA GPUs | ||||
|     rocm_only: marks tests that should only run on hosts with ROCm GPUs | ||||
|     darwin_only: marks tests that should only run on macOS | ||||
|     xpu_only: marks tests that should only run on hosts with Intel XPUs | ||||
|     npu_only: marks tests that should only run on Ascend NPUs | ||||
|     token: enable tests that require a write token | ||||
|     is_staging_test: Marks tests that should only run on a staging environment | ||||
| @ -1,23 +1,46 @@ | ||||
| import importlib.metadata | ||||
|  | ||||
| __version__ = importlib.metadata.version("kernels") | ||||
|  | ||||
| from kernels.layer import ( | ||||
|     CUDAProperties, | ||||
|     Device, | ||||
|     LayerRepository, | ||||
|     LocalLayerRepository, | ||||
|     LockedLayerRepository, | ||||
|     Mode, | ||||
|     kernelize, | ||||
|     register_kernel_mapping, | ||||
|     replace_kernel_forward_from_hub, | ||||
|     use_kernel_forward_from_hub, | ||||
|     use_kernel_mapping, | ||||
| ) | ||||
| from kernels.utils import ( | ||||
|     get_kernel, | ||||
|     get_local_kernel, | ||||
|     get_locked_kernel, | ||||
|     has_kernel, | ||||
|     install_kernel, | ||||
|     load_kernel, | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     "get_kernel", | ||||
|     "get_locked_kernel", | ||||
|     "load_kernel", | ||||
|     "install_kernel", | ||||
|     "use_kernel_forward_from_hub", | ||||
|     "register_kernel_mapping", | ||||
|     "LayerRepository", | ||||
|     "__version__", | ||||
|     "CUDAProperties", | ||||
|     "Device", | ||||
|     "LayerRepository", | ||||
|     "LocalLayerRepository", | ||||
|     "LockedLayerRepository", | ||||
|     "Mode", | ||||
|     "get_kernel", | ||||
|     "get_local_kernel", | ||||
|     "get_locked_kernel", | ||||
|     "has_kernel", | ||||
|     "install_kernel", | ||||
|     "kernelize", | ||||
|     "load_kernel", | ||||
|     "register_kernel_mapping", | ||||
|     "replace_kernel_forward_from_hub", | ||||
|     "use_kernel_forward_from_hub", | ||||
|     "use_kernel_mapping", | ||||
| ] | ||||
|  | ||||
							
								
								
									
										200
									
								
								src/kernels/_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								src/kernels/_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,200 @@ | ||||
| # AVL-balanced interval trees. We could use the intervaltree | ||||
| # packages, but it seems unmaintained and does not have type | ||||
| # annotations. | ||||
|  | ||||
| from typing import Generic, List, Optional, Tuple, TypeVar | ||||
|  | ||||
| T = TypeVar("T") | ||||
|  | ||||
|  | ||||
| class _Node(Generic[T]): | ||||
|     """A node in the interval tree.""" | ||||
|  | ||||
|     def __init__(self, start: int, end: int, data: T): | ||||
|         self.start: int = start | ||||
|         self.end: int = end | ||||
|         self.data: T = data | ||||
|         self.max_end: int = end | ||||
|         self.left: Optional["_Node[T]"] = None | ||||
|         self.right: Optional["_Node[T]"] = None | ||||
|         self.height: int = 1 | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f"Node({self.start}, {self.end})" | ||||
|  | ||||
|  | ||||
| class IntervalTree(Generic[T]): | ||||
|     """A data structure to hold and query (unique) intervals.""" | ||||
|  | ||||
|     root: Optional[_Node[T]] | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.root = None | ||||
|  | ||||
|     def insert(self, start: int, end: int, data: T) -> None: | ||||
|         """ | ||||
|         Inserts a new interval into the tree. | ||||
|  | ||||
|         Args: | ||||
|             start: The starting point of the interval. | ||||
|             end: The ending point of the interval. | ||||
|             data: The data associated with this interval. | ||||
|         """ | ||||
|         self.root = self._insert(self.root, start, end, data) | ||||
|  | ||||
|     def _get_height(self, node: Optional[_Node[T]]) -> int: | ||||
|         if not node: | ||||
|             return 0 | ||||
|         return node.height | ||||
|  | ||||
|     def _get_balance(self, node: Optional[_Node[T]]) -> int: | ||||
|         if not node: | ||||
|             return 0 | ||||
|         return self._get_height(node.left) - self._get_height(node.right) | ||||
|  | ||||
|     def _update_node_attributes(self, node: _Node[T]) -> None: | ||||
|         node.height = 1 + max(self._get_height(node.left), self._get_height(node.right)) | ||||
|         node.max_end = node.end | ||||
|         if node.left: | ||||
|             node.max_end = max(node.max_end, node.left.max_end) | ||||
|         if node.right: | ||||
|             node.max_end = max(node.max_end, node.right.max_end) | ||||
|  | ||||
|     def _right_rotate(self, y: _Node[T]) -> _Node[T]: | ||||
|         """Performs a right rotation.""" | ||||
|         x = y.left | ||||
|         assert x is not None | ||||
|         T2 = x.right | ||||
|  | ||||
|         x.right = y | ||||
|         y.left = T2 | ||||
|  | ||||
|         self._update_node_attributes(y) | ||||
|         self._update_node_attributes(x) | ||||
|  | ||||
|         return x | ||||
|  | ||||
|     def _left_rotate(self, x: _Node[T]) -> _Node[T]: | ||||
|         """Performs a left rotation.""" | ||||
|         y = x.right | ||||
|         assert y is not None | ||||
|         T2 = y.left | ||||
|  | ||||
|         y.left = x | ||||
|         x.right = T2 | ||||
|  | ||||
|         self._update_node_attributes(x) | ||||
|         self._update_node_attributes(y) | ||||
|  | ||||
|         return y | ||||
|  | ||||
|     def _insert( | ||||
|         self, node: Optional[_Node[T]], start: int, end: int, data: T | ||||
|     ) -> _Node[T]: | ||||
|         """Recursive helper to insert a new node and balance the tree.""" | ||||
|         if not node: | ||||
|             return _Node(start, end, data) | ||||
|  | ||||
|         # Replace the data if the interval already exists. | ||||
|         if start == node.start and end == node.end: | ||||
|             node.data = data | ||||
|             return node | ||||
|  | ||||
|         if start < node.start: | ||||
|             node.left = self._insert(node.left, start, end, data) | ||||
|         else: | ||||
|             node.right = self._insert(node.right, start, end, data) | ||||
|  | ||||
|         self._update_node_attributes(node) | ||||
|  | ||||
|         balance = self._get_balance(node) | ||||
|  | ||||
|         # Left Left Case | ||||
|         if balance > 1 and node.left and start < node.left.start: | ||||
|             return self._right_rotate(node) | ||||
|  | ||||
|         # Right Right Case | ||||
|         if balance < -1 and node.right and start >= node.right.start: | ||||
|             return self._left_rotate(node) | ||||
|  | ||||
|         # Left Right Case | ||||
|         if balance > 1 and node.left and start >= node.left.start: | ||||
|             node.left = self._left_rotate(node.left) | ||||
|             return self._right_rotate(node) | ||||
|  | ||||
|         # Right Left Case | ||||
|         if balance < -1 and node.right and start < node.right.start: | ||||
|             node.right = self._right_rotate(node.right) | ||||
|             return self._left_rotate(node) | ||||
|  | ||||
|         return node | ||||
|  | ||||
|     def search(self, point: int) -> List[T]: | ||||
|         """ | ||||
|         Searches for all intervals that contain the given point. | ||||
|  | ||||
|         Args: | ||||
|             point: The point to search for. | ||||
|  | ||||
|         Returns: | ||||
|             A list of data items from all matching intervals. | ||||
|         """ | ||||
|         results: List[T] = [] | ||||
|         self._search(self.root, point, results) | ||||
|         return results | ||||
|  | ||||
|     def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None: | ||||
|         """Recursive helper to find all overlapping intervals.""" | ||||
|         if node is None or point > node.max_end: | ||||
|             return | ||||
|  | ||||
|         if node.left: | ||||
|             self._search(node.left, point, results) | ||||
|  | ||||
|         if node.start <= point <= node.end: | ||||
|             results.append(node.data) | ||||
|  | ||||
|         if point >= node.start and node.right: | ||||
|             self._search(node.right, point, results) | ||||
|  | ||||
|     def find_smallest_interval(self, point: int) -> Optional[T]: | ||||
|         """ | ||||
|         Finds the item with the most specific (smallest) range for a given point. | ||||
|  | ||||
|         Args: | ||||
|             point: The capability to look up. | ||||
|  | ||||
|         Returns: | ||||
|             The data of the best-matching item, or None if no match is found. | ||||
|         """ | ||||
|         matches: List[Tuple[int, int, T]] = [] | ||||
|         self._find_with_intervals(self.root, point, matches) | ||||
|  | ||||
|         if not matches: | ||||
|             return None | ||||
|  | ||||
|         # Return the smallest interval, sort by memory location when | ||||
|         # there are multiple matches with the same interval size. This | ||||
|         # is just to ensure that we can compare against a trivial | ||||
|         # implementation in tests. | ||||
|         best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2]))) | ||||
|         return best_match[2] | ||||
|  | ||||
|     def _find_with_intervals( | ||||
|         self, | ||||
|         node: Optional[_Node[T]], | ||||
|         point: int, | ||||
|         results: List[Tuple[int, int, T]], | ||||
|     ) -> None: | ||||
|         """A modified search that collects interval ranges along with data.""" | ||||
|         if node is None or point > node.max_end: | ||||
|             return | ||||
|  | ||||
|         if node.left: | ||||
|             self._find_with_intervals(node.left, point, results) | ||||
|  | ||||
|         if node.start <= point <= node.end: | ||||
|             results.append((node.start, node.end, node.data)) | ||||
|  | ||||
|         if point >= node.start and node.right: | ||||
|             self._find_with_intervals(node.right, point, results) | ||||
							
								
								
									
										751
									
								
								src/kernels/_vendored/convert_rst_to_mdx.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										751
									
								
								src/kernels/_vendored/convert_rst_to_mdx.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,751 @@ | ||||
| # coding=utf-8 | ||||
| # Copyright 2021 The HuggingFace Team. All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
|  | ||||
| # Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py | ||||
|  | ||||
| import re | ||||
|  | ||||
| # Re pattern to catch things inside ` ` in :obj:`thing`. | ||||
| _re_obj = re.compile(r":obj:`([^`]+)`") | ||||
| # Re pattern to catch things inside ` ` in :math:`thing`. | ||||
| _re_math = re.compile(r":math:`([^`]+)`") | ||||
| # Re pattern to catch things between single backquotes. | ||||
| _re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)") | ||||
| # Re pattern to catch things between double backquotes. | ||||
| _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)") | ||||
| # Re pattern to catch things inside ` ` in :func/class/meth:`thing`. | ||||
| _re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`") | ||||
|  | ||||
|  | ||||
| def convert_rst_formatting(text): | ||||
|     """ | ||||
|     Convert rst syntax for formatting to markdown in a given text. | ||||
|     """ | ||||
|     # Remove :class:, :func: and :meth: markers. To code-links and put double backquotes | ||||
|     # (to not be caught by the italic conversion). | ||||
|     text = _re_func_class.sub(r"[``\1``]", text) | ||||
|     # Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes | ||||
|     # (to not be caught by the italic conversion). | ||||
|     text = _re_obj.sub(r"``\1``", text) | ||||
|     # Remove :math: markers. | ||||
|     text = _re_math.sub(r"\\\\(\1\\\\)", text) | ||||
|     # Convert content in single backquotes to italic. | ||||
|     text = _re_single_backquotes.sub(r"\1*\2*\3", text) | ||||
|     # Convert content in double backquotes to single backquotes. | ||||
|     text = _re_double_backquotes.sub(r"\1`\2`\3", text) | ||||
|     # Remove remaining :: | ||||
|     text = re.sub(r"::\n", "", text) | ||||
|  | ||||
|     # Remove new lines inside blocks in backsticks as they will be kept. | ||||
|     lines = text.split("\n") | ||||
|     in_code = False | ||||
|     text = None | ||||
|     for line in lines: | ||||
|         if in_code: | ||||
|             splits = line.split("`") | ||||
|             in_code = len(splits) > 1 and len(splits) % 2 == 1 | ||||
|             if len(splits) == 1: | ||||
|                 # Some forgotten lone backstick | ||||
|                 text += "\n" + line | ||||
|             else: | ||||
|                 text += " " + line.lstrip() | ||||
|         else: | ||||
|             if text is not None: | ||||
|                 text += "\n" + line | ||||
|             else: | ||||
|                 text = line | ||||
|             splits = line.split("`") | ||||
|             in_code = len(splits) % 2 == 0 | ||||
|     return text | ||||
|  | ||||
|  | ||||
| # Re pattern to catch description and url in links of the form `description <url>`_. | ||||
| _re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+") | ||||
| # Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_. | ||||
| _re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`") | ||||
| # Re pattern to catch reference in links of the form :doc:`reference`. | ||||
| _re_simple_doc = re.compile(r":doc:`([^`<]*)`") | ||||
| # Re pattern to catch description and reference in links of the form :doc:`description <reference>`. | ||||
| _re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`") | ||||
| # Re pattern to catch reference in links of the form :ref:`reference`. | ||||
| _re_simple_ref = re.compile(r":ref:`([^`<]*)`") | ||||
| # Re pattern to catch description and reference in links of the form :ref:`description <reference>`. | ||||
| _re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`") | ||||
|  | ||||
|  | ||||
| def convert_rst_links(text, page_info): | ||||
|     """ | ||||
|     Convert the rst links in text to markdown. | ||||
|     """ | ||||
|     if "package_name" not in page_info: | ||||
|         raise ValueError("`page_info` must contain at least the package_name.") | ||||
|     package_name = page_info["package_name"] | ||||
|     version = page_info.get("version", "main") | ||||
|     language = page_info.get("language", "en") | ||||
|     no_prefix = page_info.get("no_prefix", False) | ||||
|  | ||||
|     prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/" | ||||
|     # Links of the form :doc:`page` | ||||
|     text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text) | ||||
|     # Links of the form :doc:`text <page>` | ||||
|     text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text) | ||||
|  | ||||
|     if "page" in page_info and not no_prefix: | ||||
|         page = str(page_info["page"]) | ||||
|         if page.endswith(".html"): | ||||
|             page = page[:-5] | ||||
|         prefix = f"{prefix}{page}" | ||||
|     else: | ||||
|         prefix = "" | ||||
|     # Refs of the form :ref:`page` | ||||
|     text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text) | ||||
|     # Refs of the form :ref:`text <page>` | ||||
|     text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text) | ||||
|  | ||||
|     # Links with a prefix | ||||
|     # TODO: when it exists, use the API to deal with prefix links properly. | ||||
|     prefix = f"https://github.com/huggingface/{package_name}/tree/main/" | ||||
|     text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text) | ||||
|     # Other links | ||||
|     text = _re_links.sub(r"[\1](\2)", text) | ||||
|     # Relative links or Transformers links need to remove the .html | ||||
|     if ( | ||||
|         "(https://https://huggingface.co/" in text | ||||
|         or re.search(r"\(\.+/", text) is not None | ||||
|     ): | ||||
|         text = text.replace(".html", "") | ||||
|     return text | ||||
|  | ||||
|  | ||||
| # Re pattern that catches examples blocks of the form `Example::`. | ||||
| _re_example = re.compile(r"^\s*(\S.*)::\s*$") | ||||
| # Re pattern that catches rst blocks of the form `.. block_name::`. | ||||
| _re_block = re.compile(r"^\s*\.\.\s+(\S+)::") | ||||
| # Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`. | ||||
| _re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$") | ||||
|  | ||||
|  | ||||
| def is_empty_line(line): | ||||
|     return len(line) == 0 or line.isspace() | ||||
|  | ||||
|  | ||||
| def find_indent(line): | ||||
|     """ | ||||
|     Returns the number of spaces that start a line indent. | ||||
|     """ | ||||
|     search = re.search(r"^(\s*)(?:\S|$)", line) | ||||
|     if search is None: | ||||
|         return 0 | ||||
|     return len(search.groups()[0]) | ||||
|  | ||||
|  | ||||
| _re_rst_option = re.compile(r"^\s*:(\S+):(.*)$") | ||||
|  | ||||
|  | ||||
| def convert_special_chars(text): | ||||
|     """ | ||||
|     Converts { and < that have special meanings in MDX. | ||||
|     """ | ||||
|     text = text.replace("{", "&lcub;") | ||||
|     # We don't want to replace those by the HTML code, so we temporarily set them at LTHTML | ||||
|     text = re.sub( | ||||
|         r"<(img|br|hr|Youtube)", r"LTHTML\1", text | ||||
|     )  # html void elements with no closing counterpart | ||||
|     _re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL) | ||||
|     while _re_lt_html.search(text): | ||||
|         text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text) | ||||
|     text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text) | ||||
|     text = text.replace("LTHTML", "<") | ||||
|     return text | ||||
|  | ||||
|  | ||||
| def parse_options(block_content): | ||||
|     """ | ||||
|     Parses the option in some rst block content. | ||||
|     """ | ||||
|     block_lines = block_content.split("\n") | ||||
|     block_indent = find_indent(block_lines[0]) | ||||
|     current_option = None | ||||
|     result = {} | ||||
|     for line in block_lines: | ||||
|         if _re_rst_option.search(line) is not None: | ||||
|             current_option, value = _re_rst_option.search(line).groups() | ||||
|             result[current_option] = value.lstrip() | ||||
|         elif find_indent(line) > block_indent: | ||||
|             result[current_option] += " " + line.lstrip() | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def apply_min_indent(text, min_indent): | ||||
|     """ | ||||
|     Make sure all lines in a text are have a minimum indentation. | ||||
|  | ||||
|     Args: | ||||
|         text (`str`): The text to treat. | ||||
|         min_indent (`int`): The minimal indentation. | ||||
|  | ||||
|     Returns: | ||||
|         `str`: The processed text. | ||||
|     """ | ||||
|     lines = text.split("\n") | ||||
|     idx = 0 | ||||
|     while idx < len(lines): | ||||
|         if is_empty_line(lines[idx]): | ||||
|             idx += 1 | ||||
|             continue | ||||
|         indent = find_indent(lines[idx]) | ||||
|         if indent < min_indent: | ||||
|             while idx < len(lines) and ( | ||||
|                 find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]) | ||||
|             ): | ||||
|                 if not is_empty_line(lines[idx]): | ||||
|                     lines[idx] = " " * (min_indent - indent) + lines[idx] | ||||
|                 idx += 1 | ||||
|         else: | ||||
|             idx += 1 | ||||
|  | ||||
|     return "\n".join(lines) | ||||
|  | ||||
|  | ||||
| def convert_rst_blocks(text, page_info): | ||||
|     """ | ||||
|     Converts rst special blocks (examples, notes) into MDX. | ||||
|     """ | ||||
|     if "package_name" not in page_info: | ||||
|         raise ValueError("`page_info` must contain at least the package_name.") | ||||
|     package_name = page_info["package_name"] | ||||
|     version = page_info.get("version", "main") | ||||
|     language = page_info.get("language", "en") | ||||
|  | ||||
|     lines = text.split("\n") | ||||
|     idx = 0 | ||||
|     new_lines = [] | ||||
|     while idx < len(lines): | ||||
|         block_type = None | ||||
|         block_info = None | ||||
|         if _re_block.search(lines[idx]) is not None: | ||||
|             block_type = _re_block.search(lines[idx]).groups()[0] | ||||
|             if _re_block_info.search(lines[idx]) is not None: | ||||
|                 block_info = _re_block_info.search(lines[idx]).groups()[0] | ||||
|         elif _re_example.search(lines[idx]) is not None: | ||||
|             block_type = "code-block-example" | ||||
|             block_info = "python" | ||||
|             example_name = _re_example.search(lines[idx]).groups()[0] | ||||
|             new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n") | ||||
|         elif lines[idx].strip() == "..": | ||||
|             block_type = "comment" | ||||
|         elif lines[idx].strip() == "::": | ||||
|             block_type = "code-block" | ||||
|  | ||||
|         if block_type is not None: | ||||
|             block_indent = find_indent(lines[idx]) | ||||
|             # Find the next nonempty line | ||||
|             idx += 1 | ||||
|             while idx < len(lines) and is_empty_line(lines[idx]): | ||||
|                 idx += 1 | ||||
|             # Grab the indent of the return line, this block will stop when we unindent under it (or has already) | ||||
|             example_indent = ( | ||||
|                 find_indent(lines[idx]) if idx < len(lines) else block_indent | ||||
|             ) | ||||
|  | ||||
|             if example_indent == block_indent: | ||||
|                 block_content = "" | ||||
|             else: | ||||
|                 block_lines = [] | ||||
|                 while idx < len(lines) and ( | ||||
|                     is_empty_line(lines[idx]) | ||||
|                     or find_indent(lines[idx]) >= example_indent | ||||
|                 ): | ||||
|                     block_lines.append(lines[idx][example_indent:]) | ||||
|                     idx += 1 | ||||
|                 block_content = "\n".join(block_lines) | ||||
|  | ||||
|             if block_type in ["code", "code-block"]: | ||||
|                 prefix = "```" if block_info is None else f"```{block_info}" | ||||
|                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n") | ||||
|             elif block_type == "code-block-example": | ||||
|                 prefix = f"<example>```{block_info}" | ||||
|                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>") | ||||
|             elif block_type == "note": | ||||
|                 new_lines.append( | ||||
|                     apply_min_indent( | ||||
|                         f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent | ||||
|                     ) | ||||
|                 ) | ||||
|             elif block_type == "warning": | ||||
|                 new_lines.append( | ||||
|                     apply_min_indent( | ||||
|                         "<Tip warning={true}>\n\n" | ||||
|                         + f"{block_content.strip()}\n\n</Tip>\n", | ||||
|                         block_indent, | ||||
|                     ) | ||||
|                 ) | ||||
|             elif block_type == "raw": | ||||
|                 new_lines.append(block_content.strip() + "\n") | ||||
|             elif block_type == "math": | ||||
|                 new_lines.append(f"$${block_content.strip()}$$\n") | ||||
|             elif block_type == "comment": | ||||
|                 new_lines.append(f"<!--{block_content.strip()}\n-->\n") | ||||
|             elif block_type == "autofunction": | ||||
|                 if block_info is not None: | ||||
|                     new_lines.append(f"[[autodoc]] {block_info}\n") | ||||
|             elif block_type == "autoclass": | ||||
|                 if block_info is not None: | ||||
|                     block = f"[[autodoc]] {block_info}\n" | ||||
|                     options = parse_options(block_content) | ||||
|                     if "special-members" in options: | ||||
|                         special_members = options["special-members"].split(", ") | ||||
|                         for special_member in special_members: | ||||
|                             block += f"    - {special_member}\n" | ||||
|                     if "members" in options: | ||||
|                         members = options["members"] | ||||
|                         if len(members) == 0: | ||||
|                             block += "    - all\n" | ||||
|                         else: | ||||
|                             for member in members.split(", "): | ||||
|                                 block += f"    - {member}\n" | ||||
|                     new_lines.append(block) | ||||
|             elif block_type == "image": | ||||
|                 options = parse_options(block_content) | ||||
|                 target = options.pop("target", None) | ||||
|                 if block_info is not None: | ||||
|                     options["src"] = block_info | ||||
|                 else: | ||||
|                     if target is None: | ||||
|                         raise ValueError("Image source not defined.") | ||||
|                     options["src"] = target | ||||
|                 # Adapt path | ||||
|                 options["src"] = options["src"].replace( | ||||
|                     "/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/" | ||||
|                 ) | ||||
|                 html_code = " ".join( | ||||
|                     [f'{key}="{value}"' for key, value in options.items()] | ||||
|                 ) | ||||
|                 new_lines.append(f"<img {html_code}/>\n") | ||||
|  | ||||
|             else: | ||||
|                 new_lines.append( | ||||
|                     f"{block_type},{block_info}\n{block_content.rstrip()}\n" | ||||
|                 ) | ||||
|  | ||||
|         else: | ||||
|             new_lines.append(lines[idx]) | ||||
|             idx += 1 | ||||
|  | ||||
|     return "\n".join(new_lines) | ||||
|  | ||||
|  | ||||
| # Re pattern that catches rst args blocks of the form `Parameters:`. | ||||
| _re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$") | ||||
| # Re pattern that catches return blocks of the form `Return:`. | ||||
| _re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$") | ||||
|  | ||||
|  | ||||
| def split_return_line(line): | ||||
|     """ | ||||
|     Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:. | ||||
|     """ | ||||
|     splits_on_colon = line.split(":") | ||||
|     idx = 1 | ||||
|     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: | ||||
|         idx += 2 | ||||
|     if idx >= len(splits_on_colon): | ||||
|         if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()): | ||||
|             return line, "" | ||||
|         return None, line | ||||
|     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) | ||||
|  | ||||
|  | ||||
| def split_raise_line(line): | ||||
|     """ | ||||
|     Split the raise line with format `SomeError some doc`. | ||||
|     """ | ||||
|     splits_on_colon = line.strip().split(" ") | ||||
|     error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:]) | ||||
|     if error_type and error_type[-1] == ":": | ||||
|         error_type = error_type[:-1] | ||||
|     return error_type, doc | ||||
|  | ||||
|  | ||||
| def split_arg_line(line): | ||||
|     """ | ||||
|     Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:. | ||||
|     """ | ||||
|     splits_on_colon = line.split(":") | ||||
|     idx = 1 | ||||
|     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: | ||||
|         idx += 2 | ||||
|     if idx >= len(splits_on_colon): | ||||
|         return line, "" | ||||
|     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) | ||||
|  | ||||
|  | ||||
| class InvalidRstDocstringError(ValueError): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| _re_parameters = re.compile( | ||||
|     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL | ||||
| ) | ||||
| _re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL) | ||||
|  | ||||
|  | ||||
| def parse_rst_docstring(docstring): | ||||
|     """ | ||||
|     Parses a docstring written in rst, in particular the list of arguments and the return type. | ||||
|     """ | ||||
|     lines = docstring.split("\n") | ||||
|     idx = 0 | ||||
|     while idx < len(lines): | ||||
|         # Parameters section | ||||
|         if _re_args.search(lines[idx]) is not None: | ||||
|             # Title of the section. | ||||
|             lines[idx] = "<parameters>\n" | ||||
|             # Find the next nonempty line | ||||
|             idx += 1 | ||||
|             while is_empty_line(lines[idx]): | ||||
|                 idx += 1 | ||||
|             # Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the | ||||
|             # Returns or Raises block. | ||||
|             param_indent = find_indent(lines[idx]) | ||||
|             while ( | ||||
|                 idx < len(lines) | ||||
|                 and find_indent(lines[idx]) == param_indent | ||||
|                 and _re_returns.search(lines[idx]) is None | ||||
|             ): | ||||
|                 intro, doc = split_arg_line(lines[idx]) | ||||
|                 # Line starting with a > after indent indicate a "section title" in the parameters. | ||||
|                 if intro.lstrip().startswith(">"): | ||||
|                     lines[idx] = intro.lstrip() | ||||
|                 else: | ||||
|                     lines[idx] = ( | ||||
|                         re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc | ||||
|                     ) | ||||
|                 idx += 1 | ||||
|                 while idx < len(lines) and ( | ||||
|                     is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent | ||||
|                 ): | ||||
|                     idx += 1 | ||||
|             lines.insert(idx, "</parameters>\n") | ||||
|             idx += 1 | ||||
|  | ||||
|         # Returns section | ||||
|         elif _re_returns.search(lines[idx]) is not None: | ||||
|             # tag is either `return` or `yield` | ||||
|             tag = _re_returns.match(lines[idx]).group(1).lower() | ||||
|             # Title of the section. | ||||
|             lines[idx] = f"<{tag}s>\n" | ||||
|             # Find the next nonempty line | ||||
|             idx += 1 | ||||
|             while is_empty_line(lines[idx]): | ||||
|                 idx += 1 | ||||
|  | ||||
|             # Grab the indent of the return line, this block will stop when we unindent under it. | ||||
|             return_indent = find_indent(lines[idx]) | ||||
|             raised_errors = [] | ||||
|             # The line may contain the return type. | ||||
|             if tag in ["return", "yield"]: | ||||
|                 return_type, return_description = split_return_line(lines[idx]) | ||||
|                 lines[idx] = return_description | ||||
|                 idx += 1 | ||||
|                 while idx < len(lines) and ( | ||||
|                     is_empty_line(lines[idx]) | ||||
|                     or find_indent(lines[idx]) >= return_indent | ||||
|                 ): | ||||
|                     idx += 1 | ||||
|             else: | ||||
|                 while idx < len(lines) and find_indent(lines[idx]) == return_indent: | ||||
|                     return_type, return_description = split_raise_line(lines[idx]) | ||||
|                     raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type) | ||||
|                     lines[idx] = "- " + raised_error + " -- " + return_description | ||||
|                     md_link = _re_md_link.match(raised_error) | ||||
|                     if md_link: | ||||
|                         raised_error = md_link[1] | ||||
|                         raised_error = re.sub( | ||||
|                             r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error | ||||
|                         ) | ||||
|                     if raised_error not in raised_errors: | ||||
|                         raised_errors.append(raised_error) | ||||
|                     idx += 1 | ||||
|                     while idx < len(lines) and ( | ||||
|                         is_empty_line(lines[idx]) | ||||
|                         or find_indent(lines[idx]) > return_indent | ||||
|                     ): | ||||
|                         idx += 1 | ||||
|  | ||||
|             lines.insert(idx, f"</{tag}s>\n") | ||||
|             idx += 1 | ||||
|  | ||||
|             # Return block finished, we insert the return type if one was specified | ||||
|             if tag in ["return", "yield"] and return_type is not None: | ||||
|                 lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n" | ||||
|             elif len(raised_errors) > 0: | ||||
|                 # raised errors | ||||
|                 lines[ | ||||
|                     idx - 1 | ||||
|                 ] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n" | ||||
|  | ||||
|         else: | ||||
|             idx += 1 | ||||
|  | ||||
|     result = "\n".join(lines) | ||||
|  | ||||
|     # combine multiple <parameters> blocks into one block | ||||
|     if result.count("<parameters>") > 1: | ||||
|         parameters_blocks = _re_parameters.findall(result) | ||||
|         parameters_blocks = [pb[0].strip() for pb in parameters_blocks] | ||||
|         parameters_str = "\n".join(parameters_blocks) | ||||
|         result = _re_parameters.sub("", result) | ||||
|         result += f"\n<parameters>{parameters_str}</parameters>\n" | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| _re_list = re.compile(r"^\s*(-|\*|\d+\.)\s") | ||||
| _re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$") | ||||
|  | ||||
|  | ||||
| def remove_indent(text): | ||||
|     """ | ||||
|     Remove indents in text, except the one linked to lists (or sublists). | ||||
|     """ | ||||
|     lines = text.split("\n") | ||||
|     # List of indents to remember for nested lists | ||||
|     current_indents = [] | ||||
|     # List of new indents to remember for nested lists | ||||
|     new_indents = [] | ||||
|     is_inside_code = False | ||||
|     code_indent = 0 | ||||
|     for idx, line in enumerate(lines): | ||||
|         # Line is an item in a list. | ||||
|         if _re_list.search(line) is not None: | ||||
|             indent = find_indent(line) | ||||
|             # Is it a new list / new level of nestedness? | ||||
|             if len(current_indents) == 0 or indent > current_indents[-1]: | ||||
|                 current_indents.append(indent) | ||||
|                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||
|                 lines[idx] = " " * new_indent + line[indent:] | ||||
|                 new_indent += len(_re_list.search(line).groups()[0]) + 1 | ||||
|                 new_indents.append(new_indent) | ||||
|             # Otherwise it's an existing level of list (current one, or previous one) | ||||
|             else: | ||||
|                 # Let's find the proper level of indentation | ||||
|                 level = len(current_indents) - 1 | ||||
|                 while level >= 0 and current_indents[level] != indent: | ||||
|                     level -= 1 | ||||
|                 current_indents = current_indents[: level + 1] | ||||
|                 new_indents = new_indents[:level] | ||||
|                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||
|                 lines[idx] = " " * new_indent + line[indent:] | ||||
|                 new_indent += len(_re_list.search(line).groups()[0]) + 1 | ||||
|                 new_indents.append(new_indent) | ||||
|  | ||||
|         # Line is an autodoc, we keep the indent for the list just after if there is one. | ||||
|         elif _re_autodoc.search(line) is not None: | ||||
|             indent = find_indent(line) | ||||
|             current_indents = [indent] | ||||
|             new_indents = [4] | ||||
|             lines[idx] = line.strip() | ||||
|  | ||||
|         # Deal with empty lines separately | ||||
|         elif is_empty_line(line): | ||||
|             lines[idx] = "" | ||||
|  | ||||
|         # Code blocks | ||||
|         elif line.lstrip().startswith("```"): | ||||
|             is_inside_code = not is_inside_code | ||||
|             if is_inside_code: | ||||
|                 code_indent = find_indent(line) | ||||
|             lines[idx] = line[code_indent:] | ||||
|         elif is_inside_code: | ||||
|             lines[idx] = line[code_indent:] | ||||
|  | ||||
|         else: | ||||
|             indent = find_indent(line) | ||||
|             if len(current_indents) > 0 and indent > current_indents[-1]: | ||||
|                 lines[idx] = " " * new_indents[-1] + line[indent:] | ||||
|             elif len(current_indents) > 0: | ||||
|                 # Let's find the proper level of indentation | ||||
|                 level = len(current_indents) - 1 | ||||
|                 while level >= 0 and current_indents[level] > indent: | ||||
|                     level -= 1 | ||||
|                 current_indents = current_indents[: level + 1] | ||||
|                 if level >= 0: | ||||
|                     if current_indents[level] < indent: | ||||
|                         new_indents = new_indents[: level + 1] | ||||
|                     else: | ||||
|                         new_indents = new_indents[:level] | ||||
|                     new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||
|                     lines[idx] = " " * new_indent + line[indent:] | ||||
|                     new_indents.append(new_indent) | ||||
|                 else: | ||||
|                     new_indents = [] | ||||
|                     lines[idx] = line[indent:] | ||||
|             else: | ||||
|                 lines[idx] = line[indent:] | ||||
|  | ||||
|     return "\n".join(lines) | ||||
|  | ||||
|  | ||||
| def base_rst_to_mdx(text, page_info, unindent=True): | ||||
|     """ | ||||
|     Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs. | ||||
|     """ | ||||
|     text = convert_rst_links(text, page_info) | ||||
|     text = convert_special_chars(text) | ||||
|     text = convert_rst_blocks(text, page_info) | ||||
|     # Convert * in lists to - to avoid the formatting conversion treat them as bold. | ||||
|     text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE) | ||||
|     text = convert_rst_formatting(text) | ||||
|     return remove_indent(text) if unindent else text | ||||
|  | ||||
|  | ||||
| def convert_rst_docstring_to_mdx(docstring, page_info): | ||||
|     """ | ||||
|     Convert a docstring written in rst to mdx. | ||||
|     """ | ||||
|     text = parse_rst_docstring(docstring) | ||||
|     return base_rst_to_mdx(text, page_info) | ||||
|  | ||||
|  | ||||
| def process_titles(lines): | ||||
|     """Converts rst titles to markdown titles.""" | ||||
|     title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ") | ||||
|     title_levels = {} | ||||
|     new_lines = [] | ||||
|     for line in lines: | ||||
|         if ( | ||||
|             len(new_lines) > 0 | ||||
|             and len(line) >= len(new_lines[-1]) | ||||
|             and len(set(line)) == 1 | ||||
|             and line[0] in title_chars | ||||
|             and line != "::" | ||||
|         ): | ||||
|             char = line[0] | ||||
|             level = title_levels.get(char, len(title_levels) + 1) | ||||
|             if level not in title_levels: | ||||
|                 title_levels[char] = level | ||||
|             new_lines[-1] = f"{'#' * level} {new_lines[-1]}" | ||||
|         else: | ||||
|             new_lines.append(line) | ||||
|     return new_lines | ||||
|  | ||||
|  | ||||
| # Matches lines with a pattern of a table new line in rst. | ||||
| _re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$") | ||||
| # Matches lines with a pattern of a table new line in rst, with a first column empty. | ||||
| _re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$") | ||||
| # Matches lines with a pattern of a first table line in rst. | ||||
| _re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$") | ||||
| # Re pattern that catches anchors of the type .. reference: | ||||
| _re_anchor_section = re.compile(r"^\.\.\s+_(\S+):") | ||||
|  | ||||
|  | ||||
| def split_pt_tf_code_blocks(text): | ||||
|     """ | ||||
|     Split PyTorch and TensorFlow specific block codes. | ||||
|     """ | ||||
|     lines = text.split("\n") | ||||
|     new_lines = [] | ||||
|     idx = 0 | ||||
|     while idx < len(lines): | ||||
|         if lines[idx].startswith("```"): | ||||
|             code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []} | ||||
|             is_pytorch = False | ||||
|             is_tensorflow = False | ||||
|             idx += 1 | ||||
|             while idx < len(lines) and lines[idx].strip() != "```": | ||||
|                 if "## PYTORCH CODE" in lines[idx]: | ||||
|                     is_pytorch = True | ||||
|                     is_tensorflow = False | ||||
|                 elif "## TENSORFLOW CODE" in lines[idx]: | ||||
|                     is_tensorflow = True | ||||
|                     is_pytorch = False | ||||
|                 elif is_pytorch: | ||||
|                     code_lines["pytorch"].append(lines[idx]) | ||||
|                 elif is_tensorflow: | ||||
|                     code_lines["tensorflow"].append(lines[idx]) | ||||
|                 else: | ||||
|                     code_lines["common"].append(lines[idx]) | ||||
|                 idx += 1 | ||||
|             if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0: | ||||
|                 block_lines = ["<frameworkcontent>", "<pt>"] | ||||
|                 block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"]) | ||||
|                 block_lines.extend(["```", "</pt>", "<tf>"]) | ||||
|                 block_lines.extend( | ||||
|                     code_lines["common"].copy() + code_lines["tensorflow"] | ||||
|                 ) | ||||
|                 block_lines.extend(["```", "</tf>", "</frameworkcontent>"]) | ||||
|                 new_lines.extend(block_lines) | ||||
|             else: | ||||
|                 block_lines = code_lines["common"] + ["```"] | ||||
|                 new_lines.extend(block_lines) | ||||
|             idx += 1 | ||||
|         else: | ||||
|             new_lines.append(lines[idx]) | ||||
|             idx += 1 | ||||
|     return "\n".join(new_lines) | ||||
|  | ||||
|  | ||||
| def convert_rst_to_mdx(rst_text, page_info, add_imports=True): | ||||
|     """ | ||||
|     Convert a document written in rst to mdx. | ||||
|     """ | ||||
|     lines = rst_text.split("\n") | ||||
|     lines = process_titles(lines) | ||||
|     if add_imports: | ||||
|         new_lines = [ | ||||
|             '<script lang="ts">', | ||||
|             '	import Tip from "$lib/Tip.svelte";', | ||||
|             '	import Youtube from "$lib/Youtube.svelte";', | ||||
|             '	import Docstring from "$lib/Docstring.svelte";', | ||||
|             '	import CodeBlock from "$lib/CodeBlock.svelte";', | ||||
|             '	import CodeBlockFw from "$lib/CodeBlockFw.svelte";', | ||||
|             '	import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";', | ||||
|             '	import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";', | ||||
|             '	import IconCopyLink from "$lib/IconCopyLink.svelte";', | ||||
|             '	import FrameworkContent from "$lib/FrameworkContent.svelte";', | ||||
|             '	import Markdown from "$lib/Markdown.svelte";', | ||||
|             '	import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";', | ||||
|             '	import Added from "$lib/Added.svelte";', | ||||
|             '	import Changed from "$lib/Changed.svelte";', | ||||
|             '	import Deprecated from "$lib/Deprecated.svelte";', | ||||
|             '	import PipelineIcon from "$lib/PipelineIcon.svelte";', | ||||
|             '	import PipelineTag from "$lib/PipelineTag.svelte";', | ||||
|             "	", | ||||
|             '	export let fw: "pt" | "tf"', | ||||
|             "</script>", | ||||
|             "<svelte:head>", | ||||
|             '<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >', | ||||
|             "</svelte:head>", | ||||
|             "", | ||||
|         ] | ||||
|     else: | ||||
|         new_lines = [] | ||||
|     for line in lines: | ||||
|         if _re_ignore_line_table.search(line) is not None: | ||||
|             continue | ||||
|         elif _re_ignore_line_table1.search(line) is not None: | ||||
|             continue | ||||
|         elif _re_sep_line_table.search(line) is not None: | ||||
|             line = line.replace("=", "-").replace("+", "|") | ||||
|         elif _re_anchor_section.search(line) is not None: | ||||
|             anchor_name = _re_anchor_section.search(line).groups()[0] | ||||
|             line = f"<a id='{anchor_name}'></a>" | ||||
|         new_lines.append(line) | ||||
|     text = "\n".join(new_lines) | ||||
|  | ||||
|     return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info)) | ||||
							
								
								
									
										52
									
								
								src/kernels/_versions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								src/kernels/_versions.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,52 @@ | ||||
| from typing import Dict, Optional | ||||
|  | ||||
| from huggingface_hub import HfApi | ||||
| from huggingface_hub.hf_api import GitRefInfo | ||||
| from packaging.specifiers import SpecifierSet | ||||
| from packaging.version import InvalidVersion, Version | ||||
|  | ||||
|  | ||||
| def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: | ||||
|     """Get kernel versions that are available in the repository.""" | ||||
|     versions = {} | ||||
|     for tag in HfApi().list_repo_refs(repo_id).tags: | ||||
|         if not tag.name.startswith("v"): | ||||
|             continue | ||||
|         try: | ||||
|             versions[Version(tag.name[1:])] = tag | ||||
|         except InvalidVersion: | ||||
|             continue | ||||
|  | ||||
|     return versions | ||||
|  | ||||
|  | ||||
| def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo: | ||||
|     """ | ||||
|     Get the locks for a kernel with the given version spec. | ||||
|  | ||||
|     The version specifier can be any valid Python version specifier: | ||||
|     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers | ||||
|     """ | ||||
|     versions = _get_available_versions(repo_id) | ||||
|     requirement = SpecifierSet(version_spec) | ||||
|     accepted_versions = sorted(requirement.filter(versions.keys())) | ||||
|  | ||||
|     if len(accepted_versions) == 0: | ||||
|         raise ValueError( | ||||
|             f"No version of `{repo_id}` satisfies requirement: {version_spec}" | ||||
|         ) | ||||
|  | ||||
|     return versions[accepted_versions[-1]] | ||||
|  | ||||
|  | ||||
| def select_revision_or_version( | ||||
|     repo_id: str, revision: Optional[str], version: Optional[str] | ||||
| ) -> str: | ||||
|     if revision is not None and version is not None: | ||||
|         raise ValueError("Either a revision or a version must be specified, not both.") | ||||
|     elif revision is None and version is None: | ||||
|         revision = "main" | ||||
|     elif version is not None: | ||||
|         revision = resolve_version_spec_as_ref(repo_id, version).target_commit | ||||
|     assert revision is not None | ||||
|     return revision | ||||
							
								
								
									
										142
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,142 @@ | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from huggingface_hub import snapshot_download | ||||
| from kernel_abi_check import ( | ||||
|     BinaryFormat, | ||||
|     IncompatibleAbi3Symbol, | ||||
|     IncompatibleMacOSVersion, | ||||
|     IncompatibleManylinuxSymbol, | ||||
|     MissingMacOSVersion, | ||||
|     NonAbi3Symbol, | ||||
|     ObjectFile, | ||||
| ) | ||||
|  | ||||
| from kernels.utils import CACHE_DIR | ||||
|  | ||||
|  | ||||
| def check_kernel( | ||||
|     *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str | ||||
| ): | ||||
|     variants_path = ( | ||||
|         Path( | ||||
|             snapshot_download( | ||||
|                 repo_id, | ||||
|                 allow_patterns=["build/*"], | ||||
|                 cache_dir=CACHE_DIR, | ||||
|                 revision=revision, | ||||
|             ) | ||||
|         ) | ||||
|         / "build" | ||||
|     ) | ||||
|  | ||||
|     has_issues = False | ||||
|     for variant_path in variants_path.iterdir(): | ||||
|         if not variant_path.is_dir(): | ||||
|             print( | ||||
|                 f"⛔ `build/` must only contain directories, found: {variant_path.name}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             has_issues = True | ||||
|             continue | ||||
|  | ||||
|         print(f"Checking variant: {variant_path.name}", file=sys.stderr) | ||||
|  | ||||
|         indent = 2 | ||||
|  | ||||
|         for dylib_path in variant_path.rglob("*.so"): | ||||
|             print_with_indent( | ||||
|                 indent, | ||||
|                 f"Dynamic library {dylib_path.relative_to(variant_path)}:", | ||||
|             ) | ||||
|  | ||||
|             o = ObjectFile(dylib_path) | ||||
|             has_issues |= check_abi3(o, python_abi, indent + 2) | ||||
|  | ||||
|             # TODO: also check operating system | ||||
|             if o.format() == BinaryFormat.ELF: | ||||
|                 has_issues |= check_manylinux(o, manylinux, indent + 2) | ||||
|             elif o.format() == BinaryFormat.MACH_O: | ||||
|                 has_issues |= check_macos(o, macos, indent + 2) | ||||
|  | ||||
|     if has_issues: | ||||
|         sys.exit(1) | ||||
|  | ||||
|  | ||||
| def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_python_abi(python_abi) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:", | ||||
|         ) | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, IncompatibleAbi3Symbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}: {violation.version_added}", | ||||
|                 ) | ||||
|             elif isinstance(violation, NonAbi3Symbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_macos(macos) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found incompatibility with macOS {macos}:", | ||||
|         ) | ||||
|  | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, MissingMacOSVersion): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     "shared library does not contain macOS version", | ||||
|                 ) | ||||
|             elif isinstance(violation, IncompatibleMacOSVersion): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"shared library requires macOS {violation.version}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🍏 compatible with macOS {macos}") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_manylinux(manylinux) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found symbols that are incompatible with {manylinux}:", | ||||
|         ) | ||||
|  | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, IncompatibleManylinuxSymbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}_{violation.dep}: {violation.version}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🐧 {manylinux} compatible") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def print_with_indent(indent: int, message: str): | ||||
|     print(f"{' ' * indent}{message}", file=sys.stderr) | ||||
| @ -4,10 +4,15 @@ import json | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from huggingface_hub import create_repo, upload_folder | ||||
|  | ||||
| from kernels.compat import tomllib | ||||
| from kernels.lockfile import KernelLock, get_kernel_locks | ||||
| from kernels.utils import install_kernel, install_kernel_all_variants | ||||
|  | ||||
| from .doc import generate_readme_for_kernel | ||||
| from .wheel import build_variant_to_wheel | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser( | ||||
| @ -15,6 +20,31 @@ def main(): | ||||
|     ) | ||||
|     subparsers = parser.add_subparsers(required=True) | ||||
|  | ||||
|     check_parser = subparsers.add_parser("check", help="Check a kernel for compliance") | ||||
|     check_parser.add_argument("repo_id", type=str, help="The kernel repo ID") | ||||
|     check_parser.add_argument( | ||||
|         "--revision", | ||||
|         type=str, | ||||
|         default="main", | ||||
|         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", | ||||
|     ) | ||||
|     check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0") | ||||
|     check_parser.add_argument( | ||||
|         "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28" | ||||
|     ) | ||||
|     check_parser.add_argument( | ||||
|         "--python-abi", type=str, help="Python ABI version", default="3.9" | ||||
|     ) | ||||
|     check_parser.set_defaults( | ||||
|         func=lambda args: check_kernel( | ||||
|             macos=args.macos, | ||||
|             manylinux=args.manylinux, | ||||
|             python_abi=args.python_abi, | ||||
|             repo_id=args.repo_id, | ||||
|             revision=args.revision, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     download_parser = subparsers.add_parser("download", help="Download locked kernels") | ||||
|     download_parser.add_argument( | ||||
|         "project_dir", | ||||
| @ -28,6 +58,29 @@ def main(): | ||||
|     ) | ||||
|     download_parser.set_defaults(func=download_kernels) | ||||
|  | ||||
|     upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub") | ||||
|     upload_parser.add_argument( | ||||
|         "kernel_dir", | ||||
|         type=Path, | ||||
|         help="Directory of the kernel build", | ||||
|     ) | ||||
|     upload_parser.add_argument( | ||||
|         "--repo_id", | ||||
|         type=str, | ||||
|         help="Repository ID to use to upload to the Hugging Face Hub", | ||||
|     ) | ||||
|     upload_parser.add_argument( | ||||
|         "--branch", | ||||
|         type=None, | ||||
|         help="If set, the upload will be made to a particular branch of the provided `repo_id`.", | ||||
|     ) | ||||
|     upload_parser.add_argument( | ||||
|         "--private", | ||||
|         action="store_true", | ||||
|         help="If the repository should be private.", | ||||
|     ) | ||||
|     upload_parser.set_defaults(func=upload_kernels) | ||||
|  | ||||
|     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") | ||||
|     lock_parser.add_argument( | ||||
|         "project_dir", | ||||
| @ -36,6 +89,47 @@ def main(): | ||||
|     ) | ||||
|     lock_parser.set_defaults(func=lock_kernels) | ||||
|  | ||||
|     to_wheel_parser = subparsers.add_parser( | ||||
|         "to-wheel", help="Convert a kernel to a wheel file" | ||||
|     ) | ||||
|     to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID") | ||||
|     to_wheel_parser.add_argument("version", type=str, help="The kernel version") | ||||
|     to_wheel_parser.add_argument( | ||||
|         "--python-version", | ||||
|         type=str, | ||||
|         default="3.9", | ||||
|         help="The minimum Python version. Must match the Python version that the kernel was compiled for.", | ||||
|     ) | ||||
|     to_wheel_parser.add_argument( | ||||
|         "--manylinux-version", | ||||
|         type=str, | ||||
|         default="2.28", | ||||
|         help="The manylinux version. Must match the manylinux version that the kernel was compiled for.", | ||||
|     ) | ||||
|     to_wheel_parser.set_defaults(func=kernels_to_wheel) | ||||
|  | ||||
|     # Add generate-readme subcommand parser | ||||
|     generate_readme_parser = subparsers.add_parser( | ||||
|         "generate-readme", | ||||
|         help="Generate README snippets for a kernel's public functions", | ||||
|     ) | ||||
|     generate_readme_parser.add_argument( | ||||
|         "repo_id", | ||||
|         type=str, | ||||
|         help="The kernel repo ID (e.g., kernels-community/activation)", | ||||
|     ) | ||||
|     generate_readme_parser.add_argument( | ||||
|         "--revision", | ||||
|         type=str, | ||||
|         default="main", | ||||
|         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", | ||||
|     ) | ||||
|     generate_readme_parser.set_defaults( | ||||
|         func=lambda args: generate_readme_for_kernel( | ||||
|             repo_id=args.repo_id, revision=args.revision | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     args.func(args) | ||||
|  | ||||
| @ -77,6 +171,24 @@ def download_kernels(args): | ||||
|         sys.exit(1) | ||||
|  | ||||
|  | ||||
| def kernels_to_wheel(args): | ||||
|     variants_path = install_kernel_all_variants( | ||||
|         repo_id=args.repo_id, revision=f"v{args.version}" | ||||
|     ) | ||||
|     for variant_path in variants_path.iterdir(): | ||||
|         if not variant_path.is_dir(): | ||||
|             continue | ||||
|         wheel_path = build_variant_to_wheel( | ||||
|             manylinux_version=args.manylinux_version, | ||||
|             python_version=args.python_version, | ||||
|             repo_id=args.repo_id, | ||||
|             version=args.version, | ||||
|             variant_path=variant_path, | ||||
|             wheel_dir=Path("."), | ||||
|         ) | ||||
|         print(f"☸️ {wheel_path.name}", file=sys.stderr) | ||||
|  | ||||
|  | ||||
| def lock_kernels(args): | ||||
|     with open(args.project_dir / "pyproject.toml", "rb") as f: | ||||
|         data = tomllib.load(f) | ||||
| @ -91,8 +203,57 @@ def lock_kernels(args): | ||||
|         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||||
|  | ||||
|  | ||||
| def upload_kernels(args): | ||||
|     kernel_dir = Path(args.kernel_dir).resolve() | ||||
|     build_dir = kernel_dir / "build" | ||||
|     if not kernel_dir.is_dir(): | ||||
|         raise ValueError(f"{kernel_dir} is not a directory") | ||||
|     if not build_dir.is_dir(): | ||||
|         raise ValueError("Couldn't find `build` directory inside `kernel_dir`") | ||||
|  | ||||
|     repo_id = create_repo( | ||||
|         repo_id=args.repo_id, private=args.private, exist_ok=True | ||||
|     ).repo_id | ||||
|  | ||||
|     delete_patterns: set[str] = set() | ||||
|     for build_variant in build_dir.iterdir(): | ||||
|         if build_variant.is_dir(): | ||||
|             delete_patterns.add(f"{build_variant.name}/**") | ||||
|  | ||||
|     upload_folder( | ||||
|         repo_id=repo_id, | ||||
|         folder_path=build_dir, | ||||
|         revision=args.branch, | ||||
|         path_in_repo="build", | ||||
|         delete_patterns=list(delete_patterns), | ||||
|         commit_message="Build uploaded using `kernels`.", | ||||
|     ) | ||||
|     print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.") | ||||
|  | ||||
|  | ||||
| class _JSONEncoder(json.JSONEncoder): | ||||
|     def default(self, o): | ||||
|         if dataclasses.is_dataclass(o): | ||||
|             return dataclasses.asdict(o) | ||||
|         return super().default(o) | ||||
|  | ||||
|  | ||||
| def check_kernel( | ||||
|     *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str | ||||
| ): | ||||
|     try: | ||||
|         import kernels.check | ||||
|     except ImportError: | ||||
|         print( | ||||
|             "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check", | ||||
|             file=sys.stderr, | ||||
|         ) | ||||
|         sys.exit(1) | ||||
|  | ||||
|     kernels.check.check_kernel( | ||||
|         macos=macos, | ||||
|         manylinux=manylinux, | ||||
|         python_abi=python_abi, | ||||
|         repo_id=repo_id, | ||||
|         revision=revision, | ||||
|     ) | ||||
|  | ||||
							
								
								
									
										242
									
								
								src/kernels/doc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								src/kernels/doc.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,242 @@ | ||||
| import inspect | ||||
| import re | ||||
| import sys | ||||
| from types import ModuleType | ||||
|  | ||||
| import yaml | ||||
|  | ||||
| from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx | ||||
| from .utils import get_kernel | ||||
|  | ||||
| _RE_PARAMETERS = re.compile( | ||||
|     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL | ||||
| ) | ||||
| _RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL) | ||||
| _RE_RETURNTYPE = re.compile( | ||||
|     r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL | ||||
| ) | ||||
|  | ||||
|  | ||||
| def _extract_description_before_tags(docstring_mdx: str) -> str: | ||||
|     """Extract the description part of a docstring before any tags.""" | ||||
|     params_pos = docstring_mdx.find("<parameters>") | ||||
|     returns_pos = docstring_mdx.find("<returns>") | ||||
|     returntype_pos = docstring_mdx.find("<returntype>") | ||||
|     positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1] | ||||
|  | ||||
|     if positions: | ||||
|         first_tag_pos = min(positions) | ||||
|         return docstring_mdx[:first_tag_pos].strip() | ||||
|     else: | ||||
|         return docstring_mdx.strip() | ||||
|  | ||||
|  | ||||
| def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None: | ||||
|     """Print the parameters section from a docstring.""" | ||||
|     matches = _RE_PARAMETERS.findall(docstring_mdx) | ||||
|     if matches: | ||||
|         header = "#" * header_level | ||||
|         print(f"\n{header} Parameters") | ||||
|         for match in matches: | ||||
|             print(f"\n{match[0].strip()}") | ||||
|  | ||||
|  | ||||
| def _print_returns_section( | ||||
|     docstring_mdx: str, *, context_name: str, header_level: int | ||||
| ) -> None: | ||||
|     """Print the returns section from a docstring.""" | ||||
|     return_matches = _RE_RETURNS.findall(docstring_mdx) | ||||
|     returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx) | ||||
|  | ||||
|     if return_matches or returntype_matches: | ||||
|         header = "#" * header_level | ||||
|         print(f"\n{header} Returns") | ||||
|  | ||||
|         if returntype_matches: | ||||
|             if len(returntype_matches) > 1: | ||||
|                 raise ValueError( | ||||
|                     f"More than one <returntype> tag found in docstring for {context_name}" | ||||
|                 ) | ||||
|             print(f"\n**Type**: {returntype_matches[0][0].strip()}") | ||||
|  | ||||
|         if return_matches: | ||||
|             for match in return_matches: | ||||
|                 print(f"\n{match[0].strip()}") | ||||
|  | ||||
|  | ||||
| def _get_docstring(obj, use_dict_check: bool = False) -> str: | ||||
|     """Get docstring from an object, with fallback to default message.""" | ||||
|     # Check whether the class/method itself has docs and not just | ||||
|     # the superclass. | ||||
|     if use_dict_check: | ||||
|         has_doc = obj.__dict__.get("__doc__", None) is not None | ||||
|     else: | ||||
|         has_doc = getattr(obj, "__doc__", None) is not None | ||||
|  | ||||
|     # We use inspect.getdoc because it does normalization. | ||||
|     doc = inspect.getdoc(obj) | ||||
|  | ||||
|     return doc if has_doc and doc is not None else "No documentation available." | ||||
|  | ||||
|  | ||||
| def _process_and_print_docstring( | ||||
|     docstring: str, *, kernel_name: str, context_name: str, header_level: int | ||||
| ) -> None: | ||||
|     """Convert docstring to MDX and print description, parameters, and returns sections.""" | ||||
|     docstring_mdx = convert_rst_docstring_to_mdx( | ||||
|         docstring, page_info={"package_name": kernel_name} | ||||
|     ) | ||||
|  | ||||
|     # Print the description | ||||
|     description = _extract_description_before_tags(docstring_mdx) | ||||
|     print(f"\n{description}") | ||||
|  | ||||
|     # Print parameters and returns sections | ||||
|     _print_parameters_section(docstring_mdx, header_level=header_level) | ||||
|     _print_returns_section( | ||||
|         docstring_mdx, context_name=context_name, header_level=header_level | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None: | ||||
|     kernel_module = get_kernel(repo_id=repo_id, revision=revision) | ||||
|     kernel_name = repo_id.split("/")[-1].replace("-", "_") | ||||
|  | ||||
|     generate_metadata(kernel_module) | ||||
|     generate_kernel_doc(kernel_module, kernel_name) | ||||
|     generate_function_doc(kernel_module, kernel_name) | ||||
|     generate_layers_doc(kernel_module, kernel_name) | ||||
|  | ||||
|  | ||||
| def generate_metadata(module: ModuleType) -> None: | ||||
|     metadata = getattr(module, "__kernel_metadata__", {}) | ||||
|     if "tags" not in metadata: | ||||
|         metadata["tags"] = ["kernel"] | ||||
|     else: | ||||
|         if "kernel" not in metadata["tags"]: | ||||
|             metadata["tags"].append("kernel") | ||||
|  | ||||
|     print("---") | ||||
|     print(yaml.dump(metadata), end="") | ||||
|     print("---") | ||||
|  | ||||
|  | ||||
| def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None: | ||||
|     docstring = module.__doc__.strip() if module.__doc__ is not None else None | ||||
|     if docstring: | ||||
|         title, rest = docstring.split("\n", 1) | ||||
|         print(f"# {title.strip()}") | ||||
|         print( | ||||
|             f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}" | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||
|     print("\n## Functions") | ||||
|  | ||||
|     # Track if we found any functions | ||||
|     found_functions = False | ||||
|  | ||||
|     for name, func in inspect.getmembers(kernel_module, inspect.isfunction): | ||||
|         # Do not include imported functions. | ||||
|         if func.__module__ != kernel_module.__name__: | ||||
|             continue | ||||
|  | ||||
|         # Exclude private functions. | ||||
|         if name.startswith("_"): | ||||
|             continue | ||||
|  | ||||
|         found_functions = True | ||||
|  | ||||
|         try: | ||||
|             sig = inspect.signature(func) | ||||
|             docstring = _get_docstring(func) | ||||
|         except ValueError: | ||||
|             print( | ||||
|                 f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             continue | ||||
|  | ||||
|         print(f"\n### Function `{name}`") | ||||
|         print(f"\n`{sig}`") | ||||
|  | ||||
|         _process_and_print_docstring( | ||||
|             docstring, kernel_name=kernel_name, context_name=name, header_level=3 | ||||
|         ) | ||||
|  | ||||
|     if not found_functions: | ||||
|         print("\nNo public top-level functions.") | ||||
|  | ||||
|  | ||||
| def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||
|     # Check if layers module is available | ||||
|     layers_module = getattr(kernel_module, "layers", None) | ||||
|     if layers_module is None: | ||||
|         return | ||||
|  | ||||
|     print("\n## Layers") | ||||
|  | ||||
|     # Track if we found any classes | ||||
|     found_classes = False | ||||
|  | ||||
|     for class_name, cls in inspect.getmembers(layers_module, inspect.isclass): | ||||
|         # Exclude classes that were imported. | ||||
|         if cls.__module__ != layers_module.__name__: | ||||
|             continue | ||||
|  | ||||
|         found_classes = True | ||||
|  | ||||
|         try: | ||||
|             # Get docstring, but not from superclasses. | ||||
|             class_docstring = _get_docstring(cls, use_dict_check=True) | ||||
|         except Exception: | ||||
|             print( | ||||
|                 f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             continue | ||||
|  | ||||
|         print(f"\n### Class `{class_name}`") | ||||
|  | ||||
|         # Always print class description (helper handles conversion and formatting) | ||||
|         class_docstring_mdx = convert_rst_docstring_to_mdx( | ||||
|             class_docstring, page_info={"package_name": kernel_name} | ||||
|         ) | ||||
|         description = _extract_description_before_tags(class_docstring_mdx) | ||||
|         print(f"\n{description}") | ||||
|  | ||||
|         # Document methods | ||||
|         print("\n#### Methods") | ||||
|  | ||||
|         for method_name, method in inspect.getmembers(cls, inspect.isfunction): | ||||
|             # Note: also skip __init__, since extension layers cannot have a constructor. | ||||
|             if method_name.startswith("_"): | ||||
|                 continue | ||||
|  | ||||
|             # Skip methods from superclasses. | ||||
|             if method_name not in cls.__dict__: | ||||
|                 continue | ||||
|  | ||||
|             try: | ||||
|                 sig = inspect.signature(method) | ||||
|                 method_docstring = _get_docstring(method) | ||||
|             except ValueError: | ||||
|                 print( | ||||
|                     f"Warning: Could not retrieve signature for {method_name} in {class_name}", | ||||
|                     file=sys.stderr, | ||||
|                 ) | ||||
|                 continue | ||||
|  | ||||
|             print(f"\n##### Method `{method_name}`") | ||||
|             print(f"\n`{sig}`") | ||||
|  | ||||
|             _process_and_print_docstring( | ||||
|                 method_docstring, | ||||
|                 kernel_name=kernel_name, | ||||
|                 context_name=method_name, | ||||
|                 header_level=6, | ||||
|             ) | ||||
|  | ||||
|     if not found_classes: | ||||
|         print("\nNo layers defined.") | ||||
							
								
								
									
										1205
									
								
								src/kernels/layer.py
									
									
									
									
									
								
							
							
						
						
									
										1205
									
								
								src/kernels/layer.py
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -4,10 +4,8 @@ from pathlib import Path | ||||
| from typing import Dict, List, Tuple | ||||
|  | ||||
| from huggingface_hub import HfApi | ||||
| from huggingface_hub.hf_api import GitRefInfo | ||||
| from packaging.specifiers import SpecifierSet | ||||
| from packaging.version import InvalidVersion, Version | ||||
|  | ||||
| from kernels._versions import resolve_version_spec_as_ref | ||||
| from kernels.compat import tomllib | ||||
|  | ||||
|  | ||||
| @ -31,20 +29,6 @@ class KernelLock: | ||||
|         return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) | ||||
|  | ||||
|  | ||||
| def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: | ||||
|     """Get kernel versions that are available in the repository.""" | ||||
|     versions = {} | ||||
|     for tag in HfApi().list_repo_refs(repo_id).tags: | ||||
|         if not tag.name.startswith("v"): | ||||
|             continue | ||||
|         try: | ||||
|             versions[Version(tag.name[1:])] = tag | ||||
|         except InvalidVersion: | ||||
|             continue | ||||
|  | ||||
|     return versions | ||||
|  | ||||
|  | ||||
| def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | ||||
|     """ | ||||
|     Get the locks for a kernel with the given version spec. | ||||
| @ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | ||||
|     The version specifier can be any valid Python version specifier: | ||||
|     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers | ||||
|     """ | ||||
|     versions = _get_available_versions(repo_id) | ||||
|     requirement = SpecifierSet(version_spec) | ||||
|     accepted_versions = sorted(requirement.filter(versions.keys())) | ||||
|  | ||||
|     if len(accepted_versions) == 0: | ||||
|         raise ValueError( | ||||
|             f"No version of `{repo_id}` satisfies requirement: {version_spec}" | ||||
|         ) | ||||
|  | ||||
|     tag_for_newest = versions[accepted_versions[-1]] | ||||
|     tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec) | ||||
|  | ||||
|     r = HfApi().repo_info( | ||||
|         repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True | ||||
|  | ||||
| @ -4,6 +4,7 @@ import importlib | ||||
| import importlib.metadata | ||||
| import inspect | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| import platform | ||||
| import sys | ||||
| @ -12,29 +13,71 @@ from pathlib import Path | ||||
| from types import ModuleType | ||||
| from typing import Dict, List, Optional, Tuple | ||||
|  | ||||
| from huggingface_hub import snapshot_download | ||||
| from huggingface_hub import file_exists, snapshot_download | ||||
| from packaging.version import parse | ||||
|  | ||||
| from kernels._versions import select_revision_or_version | ||||
| from kernels.lockfile import KernelLock, VariantLock | ||||
|  | ||||
| CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None) | ||||
|  | ||||
| def _get_cache_dir() -> Optional[str]: | ||||
|     """Returns the kernels cache directory.""" | ||||
|     cache_dir = os.environ.get("HF_KERNELS_CACHE", None) | ||||
|     if cache_dir is not None: | ||||
|         logging.warning( | ||||
|             "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead" | ||||
|         ) | ||||
|         return cache_dir | ||||
|  | ||||
|     return os.environ.get("KERNELS_CACHE", None) | ||||
|  | ||||
|  | ||||
| CACHE_DIR: Optional[str] = _get_cache_dir() | ||||
|  | ||||
|  | ||||
| def _get_privateuse_backend_name() -> Optional[str]: | ||||
|     import torch | ||||
|  | ||||
|     if hasattr(torch._C, "_get_privateuse1_backend_name"): | ||||
|         return torch._C._get_privateuse1_backend_name() | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def build_variant() -> str: | ||||
|     import torch | ||||
|  | ||||
|     if torch.version.cuda is None: | ||||
|     if torch.version.cuda is not None: | ||||
|         cuda_version = parse(torch.version.cuda) | ||||
|         compute_framework = f"cu{cuda_version.major}{cuda_version.minor}" | ||||
|     elif torch.version.hip is not None: | ||||
|         rocm_version = parse(torch.version.hip.split("-")[0]) | ||||
|         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" | ||||
|     elif torch.backends.mps.is_available(): | ||||
|         compute_framework = "metal" | ||||
|     elif torch.version.xpu is not None: | ||||
|         version = torch.version.xpu | ||||
|         compute_framework = f"xpu{version[0:4]}{version[5:6]}" | ||||
|     elif _get_privateuse_backend_name() == "npu": | ||||
|         from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found] | ||||
|  | ||||
|         cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] | ||||
|         compute_framework = f"cann{cann_major}{cann_minor}" | ||||
|     else: | ||||
|         raise AssertionError( | ||||
|             "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled." | ||||
|             "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." | ||||
|         ) | ||||
|  | ||||
|     torch_version = parse(torch.__version__) | ||||
|     cuda_version = parse(torch.version.cuda) | ||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||
|     cpu = platform.machine() | ||||
|     os = platform.system().lower() | ||||
|  | ||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}" | ||||
|     if os == "darwin": | ||||
|         cpu = "aarch64" if cpu == "arm64" else cpu | ||||
|         return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" | ||||
|  | ||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||
|  | ||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}" | ||||
|  | ||||
|  | ||||
| def universal_build_variant() -> str: | ||||
| @ -69,7 +112,20 @@ def install_kernel( | ||||
|     """ | ||||
|     Download a kernel for the current environment to the cache. | ||||
|  | ||||
|     The output path is validated againt `hash` when set. | ||||
|     The output path is validated against the hashes in `variant_locks` when provided. | ||||
|  | ||||
|     Args: | ||||
|         repo_id (`str`): | ||||
|             The Hub repository containing the kernel. | ||||
|         revision (`str`): | ||||
|             The specific revision (branch, tag, or commit) to download. | ||||
|         local_files_only (`bool`, *optional*, defaults to `False`): | ||||
|             Whether to only use local files and not download from the Hub. | ||||
|         variant_locks (`Dict[str, VariantLock]`, *optional*): | ||||
|             Optional dictionary of variant locks for validation. | ||||
|  | ||||
|     Returns: | ||||
|         `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory. | ||||
|     """ | ||||
|     package_name = package_name_from_repo_id(repo_id) | ||||
|     variant = build_variant() | ||||
| @ -84,6 +140,23 @@ def install_kernel( | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     try: | ||||
|         return _load_kernel_from_path(repo_path, package_name, variant_locks) | ||||
|     except FileNotFoundError: | ||||
|         # Redo with more specific error message. | ||||
|         raise FileNotFoundError( | ||||
|             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def _load_kernel_from_path( | ||||
|     repo_path: Path, | ||||
|     package_name: str, | ||||
|     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||
| ) -> Tuple[str, Path]: | ||||
|     variant = build_variant() | ||||
|     universal_variant = universal_build_variant() | ||||
|  | ||||
|     variant_path = repo_path / "build" / variant | ||||
|     universal_variant_path = repo_path / "build" / universal_variant | ||||
|  | ||||
| @ -102,7 +175,7 @@ def install_kernel( | ||||
|  | ||||
|     if not os.path.exists(module_init_path): | ||||
|         raise FileNotFoundError( | ||||
|             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" | ||||
|             f"Kernel at path `{repo_path}` does not have build: {variant}" | ||||
|         ) | ||||
|  | ||||
|     return package_name, variant_path | ||||
| @ -139,17 +212,128 @@ def install_kernel_all_variants( | ||||
|     return repo_path / "build" | ||||
|  | ||||
|  | ||||
| def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | ||||
| def get_kernel( | ||||
|     repo_id: str, revision: Optional[str] = None, version: Optional[str] = None | ||||
| ) -> ModuleType: | ||||
|     """ | ||||
|     Load a kernel from the kernel hub. | ||||
|  | ||||
|     This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before) | ||||
|     and then loads the kernel. | ||||
|  | ||||
|     Args: | ||||
|         repo_id (`str`): | ||||
|             The Hub repository containing the kernel. | ||||
|         revision (`str`, *optional*, defaults to `"main"`): | ||||
|             The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. | ||||
|         version (`str`, *optional*): | ||||
|             The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. | ||||
|             Cannot be used together with `revision`. | ||||
|  | ||||
|     Returns: | ||||
|         `ModuleType`: The imported kernel module. | ||||
|  | ||||
|     Example: | ||||
|         ```python | ||||
|         import torch | ||||
|         from kernels import get_kernel | ||||
|  | ||||
|         activation = get_kernel("kernels-community/activation") | ||||
|         x = torch.randn(10, 20, device="cuda") | ||||
|         out = torch.empty_like(x) | ||||
|         result = activation.silu_and_mul(out, x) | ||||
|         ``` | ||||
|     """ | ||||
|     revision = select_revision_or_version(repo_id, revision, version) | ||||
|     package_name, package_path = install_kernel(repo_id, revision=revision) | ||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||
|  | ||||
|  | ||||
| def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: | ||||
|     """ | ||||
|     Import a kernel from a local kernel repository path. | ||||
|  | ||||
|     Args: | ||||
|         repo_path (`Path`): | ||||
|             The local path to the kernel repository. | ||||
|         package_name (`str`): | ||||
|             The name of the package to import from the repository. | ||||
|  | ||||
|     Returns: | ||||
|         `ModuleType`: The imported kernel module. | ||||
|     """ | ||||
|     variant = build_variant() | ||||
|     universal_variant = universal_build_variant() | ||||
|  | ||||
|     # Presume we were given the top level path of the kernel repository. | ||||
|     for base_path in [repo_path, repo_path / "build"]: | ||||
|         # Prefer the universal variant if it exists. | ||||
|         for v in [universal_variant, variant]: | ||||
|             package_path = base_path / v / package_name / "__init__.py" | ||||
|             if package_path.exists(): | ||||
|                 return import_from_path(package_name, package_path) | ||||
|  | ||||
|     # If we didn't find the package in the repo we may have a explicit | ||||
|     # package path. | ||||
|     package_path = repo_path / package_name / "__init__.py" | ||||
|     if package_path.exists(): | ||||
|         return import_from_path(package_name, package_path) | ||||
|  | ||||
|     raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}") | ||||
|  | ||||
|  | ||||
| def has_kernel( | ||||
|     repo_id: str, revision: Optional[str] = None, version: Optional[str] = None | ||||
| ) -> bool: | ||||
|     """ | ||||
|     Check whether a kernel build exists for the current environment (Torch version and compute framework). | ||||
|  | ||||
|     Args: | ||||
|         repo_id (`str`): | ||||
|             The Hub repository containing the kernel. | ||||
|         revision (`str`, *optional*, defaults to `"main"`): | ||||
|             The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. | ||||
|         version (`str`, *optional*): | ||||
|             The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. | ||||
|             Cannot be used together with `revision`. | ||||
|  | ||||
|     Returns: | ||||
|         `bool`: `True` if a kernel is available for the current environment. | ||||
|     """ | ||||
|     revision = select_revision_or_version(repo_id, revision, version) | ||||
|  | ||||
|     package_name = package_name_from_repo_id(repo_id) | ||||
|     variant = build_variant() | ||||
|     universal_variant = universal_build_variant() | ||||
|  | ||||
|     if file_exists( | ||||
|         repo_id, | ||||
|         revision=revision, | ||||
|         filename=f"build/{universal_variant}/{package_name}/__init__.py", | ||||
|     ): | ||||
|         return True | ||||
|  | ||||
|     return file_exists( | ||||
|         repo_id, | ||||
|         revision=revision, | ||||
|         filename=f"build/{variant}/{package_name}/__init__.py", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||
|     """ | ||||
|     Get a pre-downloaded, locked kernel. | ||||
|  | ||||
|     If `lockfile` is not specified, the lockfile will be loaded from the | ||||
|     caller's package metadata. | ||||
|     If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata. | ||||
|  | ||||
|     Args: | ||||
|         repo_id (`str`): | ||||
|             The Hub repository containing the kernel. | ||||
|         lockfile (`Path`, *optional*): | ||||
|             Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata. | ||||
|  | ||||
|     Returns: | ||||
|         `ModuleType`: The imported kernel module. | ||||
|     """ | ||||
|     if lockfile is None: | ||||
|         locked_sha = _get_caller_locked_kernel(repo_id) | ||||
| @ -194,7 +378,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||
|  | ||||
|  | ||||
| def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | ||||
|     """Get a kernel using a lock file.""" | ||||
|     """ | ||||
|     Get a kernel using a lock file. | ||||
|  | ||||
|     Args: | ||||
|         repo_id (`str`): | ||||
|             The Hub repository containing the kernel. | ||||
|         local_files_only (`bool`, *optional*, defaults to `False`): | ||||
|             Whether to only use local files and not download from the Hub. | ||||
|  | ||||
|     Returns: | ||||
|         `ModuleType`: The imported kernel module. | ||||
|     """ | ||||
|     locked_sha = _get_caller_locked_kernel(repo_id) | ||||
|  | ||||
|     if locked_sha is None: | ||||
|  | ||||
							
								
								
									
										186
									
								
								src/kernels/wheel.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								src/kernels/wheel.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,186 @@ | ||||
| import email.policy | ||||
| import os | ||||
| from dataclasses import dataclass | ||||
| from email.message import Message | ||||
| from importlib.metadata import PackageNotFoundError, version | ||||
| from pathlib import Path | ||||
| from typing import Optional | ||||
|  | ||||
| try: | ||||
|     KERNELS_VERSION = version("kernels") | ||||
| except PackageNotFoundError: | ||||
|     KERNELS_VERSION = "unknown" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class Metadata: | ||||
|     name: str | ||||
|     version: str | ||||
|     cuda_version: Optional[str] | ||||
|     cxx_abi_version: Optional[str] | ||||
|     torch_version: Optional[str] | ||||
|     os: Optional[str] | ||||
|     platform: Optional[str] | ||||
|  | ||||
|     @property | ||||
|     def is_universal(self) -> bool: | ||||
|         return self.platform is None | ||||
|  | ||||
|  | ||||
| def build_variant_to_wheel( | ||||
|     repo_id: str, | ||||
|     *, | ||||
|     version: str, | ||||
|     variant_path: Path, | ||||
|     wheel_dir: Path, | ||||
|     manylinux_version: str = "2.28", | ||||
|     python_version: str = "3.9", | ||||
| ) -> Path: | ||||
|     """ | ||||
|     Create a wheel file from the variant path. | ||||
|     """ | ||||
|     name = repo_id.split("/")[-1].replace("_", "-") | ||||
|     metadata = extract_metadata(name, version, variant_path) | ||||
|     return build_wheel( | ||||
|         metadata, | ||||
|         variant_path=variant_path, | ||||
|         wheel_dir=wheel_dir, | ||||
|         manylinux_version=manylinux_version, | ||||
|         python_version=python_version, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata: | ||||
|     """ | ||||
|     Extract metadata from the variant path. | ||||
|     """ | ||||
|     if variant_path.name == "torch-universal": | ||||
|         return Metadata( | ||||
|             name=name, | ||||
|             version=version, | ||||
|             cuda_version=None, | ||||
|             cxx_abi_version=None, | ||||
|             torch_version=None, | ||||
|             os=None, | ||||
|             platform=None, | ||||
|         ) | ||||
|  | ||||
|     if not variant_path.name.startswith("torch"): | ||||
|         raise ValueError("Currently only conversion of Torch kernels is supported.") | ||||
|  | ||||
|     variant_parts = variant_path.name.removeprefix("torch").split("-") | ||||
|     if len(variant_parts) != 5: | ||||
|         raise ValueError(f"Invalid variant name: {variant_path.name}") | ||||
|  | ||||
|     torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}" | ||||
|     cpp_abi_version = variant_parts[1].removeprefix("cxx") | ||||
|     cuda_version = variant_parts[2].removeprefix("cu") | ||||
|     platform = variant_parts[3].replace("-", "_") | ||||
|     os = variant_parts[4] | ||||
|  | ||||
|     return Metadata( | ||||
|         name=name, | ||||
|         version=version, | ||||
|         cuda_version=cuda_version, | ||||
|         cxx_abi_version=cpp_abi_version, | ||||
|         torch_version=torch_version, | ||||
|         os=os, | ||||
|         platform=platform, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def build_wheel( | ||||
|     metadata: Metadata, | ||||
|     *, | ||||
|     variant_path: Path, | ||||
|     wheel_dir: Path, | ||||
|     manylinux_version: str = "2.28", | ||||
|     python_version: str = "3.9", | ||||
| ) -> Path: | ||||
|     """ | ||||
|     Build the wheel file. | ||||
|     """ | ||||
|     try: | ||||
|         from wheel.wheelfile import WheelFile  # type: ignore | ||||
|     except ImportError: | ||||
|         raise ImportError( | ||||
|             "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`" | ||||
|         ) | ||||
|  | ||||
|     name = metadata.name.replace("-", "_") | ||||
|     python_version_flat = python_version.replace(".", "") | ||||
|  | ||||
|     if metadata.is_universal: | ||||
|         python_tag = f"py{python_version_flat}" | ||||
|         abi_tag = "none" | ||||
|         platform_tag = "any" | ||||
|         wheel_filename = ( | ||||
|             f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl" | ||||
|         ) | ||||
|         dist_info_dir_name = f"{name}-{metadata.version}.dist-info" | ||||
|         root_is_purelib = "true" | ||||
|         requires_dist_torch = "torch" | ||||
|     else: | ||||
|         python_tag = f"cp{python_version_flat}" | ||||
|         abi_tag = "abi3" | ||||
|  | ||||
|         if ( | ||||
|             metadata.torch_version is None | ||||
|             or metadata.cuda_version is None | ||||
|             or metadata.cxx_abi_version is None | ||||
|             or metadata.os is None | ||||
|             or metadata.platform is None | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels." | ||||
|             ) | ||||
|  | ||||
|         local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}" | ||||
|  | ||||
|         if metadata.os == "linux": | ||||
|             platform_tag = ( | ||||
|                 f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}" | ||||
|             ) | ||||
|         else: | ||||
|             platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}" | ||||
|  | ||||
|         wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" | ||||
|         dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info" | ||||
|         root_is_purelib = "false" | ||||
|         requires_dist_torch = f"torch=={metadata.torch_version}.*" | ||||
|  | ||||
|     wheel_path = wheel_dir / wheel_filename | ||||
|  | ||||
|     wheel_msg = Message(email.policy.compat32) | ||||
|     wheel_msg.add_header("Wheel-Version", "1.0") | ||||
|     wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})") | ||||
|     wheel_msg.add_header("Root-Is-Purelib", root_is_purelib) | ||||
|     wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}") | ||||
|  | ||||
|     metadata_msg = Message(email.policy.compat32) | ||||
|     metadata_msg.add_header("Metadata-Version", "2.1") | ||||
|     metadata_msg.add_header("Name", name) | ||||
|     metadata_msg.add_header("Version", metadata.version) | ||||
|     metadata_msg.add_header("Summary", f"{name} kernel") | ||||
|     metadata_msg.add_header("Requires-Python", ">=3.9") | ||||
|     metadata_msg.add_header("Requires-Dist", requires_dist_torch) | ||||
|  | ||||
|     source_pkg_dir = variant_path / name | ||||
|  | ||||
|     with WheelFile(wheel_path, "w") as wheel_file: | ||||
|         for root, dirnames, filenames in os.walk(source_pkg_dir): | ||||
|             for filename in filenames: | ||||
|                 if filename.endswith(".pyc"): | ||||
|                     continue | ||||
|  | ||||
|                 abs_filepath = os.path.join(root, filename) | ||||
|                 entry_name = os.path.relpath(abs_filepath, variant_path) | ||||
|                 wheel_file.write(abs_filepath, entry_name) | ||||
|  | ||||
|         wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL") | ||||
|         wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8")) | ||||
|  | ||||
|         metadata_path = os.path.join(dist_info_dir_name, "METADATA") | ||||
|         wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8")) | ||||
|  | ||||
|     return wheel_path | ||||
							
								
								
									
										46
									
								
								tests/conftest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								tests/conftest.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,46 @@ | ||||
| import sys | ||||
|  | ||||
| import pytest | ||||
| import torch | ||||
|  | ||||
| from kernels.utils import _get_privateuse_backend_name | ||||
|  | ||||
| has_cuda = ( | ||||
|     hasattr(torch.version, "cuda") | ||||
|     and torch.version.cuda is not None | ||||
|     and torch.cuda.device_count() > 0 | ||||
| ) | ||||
| has_rocm = ( | ||||
|     hasattr(torch.version, "hip") | ||||
|     and torch.version.hip is not None | ||||
|     and torch.cuda.device_count() > 0 | ||||
| ) | ||||
| has_xpu = ( | ||||
|     hasattr(torch.version, "xpu") | ||||
|     and torch.version.xpu is not None | ||||
|     and torch.xpu.device_count() > 0 | ||||
| ) | ||||
| has_npu = _get_privateuse_backend_name() == "npu" | ||||
|  | ||||
|  | ||||
| def pytest_addoption(parser): | ||||
|     parser.addoption( | ||||
|         "--token", | ||||
|         action="store_true", | ||||
|         help="run tests that require a token with write permissions", | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def pytest_runtest_setup(item): | ||||
|     if "cuda_only" in item.keywords and not has_cuda: | ||||
|         pytest.skip("skipping CUDA-only test on host without CUDA") | ||||
|     if "rocm_only" in item.keywords and not has_rocm: | ||||
|         pytest.skip("skipping ROCm-only test on host without ROCm") | ||||
|     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"): | ||||
|         pytest.skip("skipping macOS-only test on non-macOS platform") | ||||
|     if "xpu_only" in item.keywords and not has_xpu: | ||||
|         pytest.skip("skipping XPU-only test on host without XPU") | ||||
|     if "npu_only" in item.keywords and not has_npu: | ||||
|         pytest.skip("skipping NPU-only test on host without NPU") | ||||
|     if "token" in item.keywords and not item.config.getoption("--token"): | ||||
|         pytest.skip("need --token option to run this test") | ||||
| @ -1,54 +1,82 @@ | ||||
| [ | ||||
|   { | ||||
|     "repo_id": "kernels-community/activation", | ||||
|     "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0", | ||||
|     "sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538", | ||||
|     "variants": { | ||||
|       "torch25-cxx11-cu118-x86_64-linux": { | ||||
|         "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d", | ||||
|         "hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch25-cxx11-cu121-x86_64-linux": { | ||||
|         "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58", | ||||
|         "hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch25-cxx11-cu124-x86_64-linux": { | ||||
|         "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a", | ||||
|         "hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch25-cxx98-cu118-x86_64-linux": { | ||||
|         "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0", | ||||
|         "hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch25-cxx98-cu121-x86_64-linux": { | ||||
|         "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8", | ||||
|         "hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch25-cxx98-cu124-x86_64-linux": { | ||||
|         "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc", | ||||
|         "hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx11-cu118-x86_64-linux": { | ||||
|         "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b", | ||||
|         "hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx11-cu124-x86_64-linux": { | ||||
|         "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005", | ||||
|         "hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx11-cu126-aarch64-linux": { | ||||
|         "hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx11-cu126-x86_64-linux": { | ||||
|         "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc", | ||||
|         "hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx98-cu118-x86_64-linux": { | ||||
|         "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f", | ||||
|         "hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx98-cu124-x86_64-linux": { | ||||
|         "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf", | ||||
|         "hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx98-cu126-aarch64-linux": { | ||||
|         "hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch26-cxx98-cu126-x86_64-linux": { | ||||
|         "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a", | ||||
|         "hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch27-cxx11-cu118-x86_64-linux": { | ||||
|         "hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch27-cxx11-cu126-aarch64-linux": { | ||||
|         "hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch27-cxx11-cu126-x86_64-linux": { | ||||
|         "hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch27-cxx11-cu128-aarch64-linux": { | ||||
|         "hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       }, | ||||
|       "torch27-cxx11-cu128-x86_64-linux": { | ||||
|         "hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       } | ||||
|     } | ||||
|  | ||||
							
								
								
									
										12
									
								
								tests/layer_locking/kernels.lock
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								tests/layer_locking/kernels.lock
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,12 @@ | ||||
| [ | ||||
|   { | ||||
|     "repo_id": "kernels-test/versions", | ||||
|     "sha": "dc142fd6c9920c993d32be6358b78957c58681c3", | ||||
|     "variants": { | ||||
|       "torch-universal": { | ||||
|         "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c", | ||||
|         "hash_type": "git_lfs_concat" | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| ] | ||||
							
								
								
									
										2
									
								
								tests/layer_locking/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								tests/layer_locking/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,2 @@ | ||||
| [tool.kernels.dependencies] | ||||
| "kernels-test/versions" = ">=0.1.0,<0.2.0" | ||||
| @ -1,7 +1,7 @@ | ||||
| import pytest | ||||
| import torch | ||||
|  | ||||
| from kernels import get_kernel | ||||
| from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| @ -9,6 +9,25 @@ def kernel(): | ||||
|     return get_kernel("kernels-community/activation") | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def local_kernel_path(): | ||||
|     package_name, path = install_kernel("kernels-community/activation", "main") | ||||
|     # Path is the build variant path (build/torch-<...>), so the grandparent | ||||
|     # is the kernel repository path. | ||||
|     return package_name, path | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def local_kernel(local_kernel_path): | ||||
|     package_name, path = local_kernel_path | ||||
|     return get_local_kernel(path.parent.parent, package_name) | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def metal_kernel(): | ||||
|     return get_kernel("kernels-test/relu-metal") | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def universal_kernel(): | ||||
|     return get_kernel("kernels-community/triton-scaled-mm") | ||||
| @ -21,6 +40,7 @@ def device(): | ||||
|     return "cuda" | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_gelu_fast(kernel, device): | ||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||
|     y = torch.empty_like(x) | ||||
| @ -36,6 +56,100 @@ def test_gelu_fast(kernel, device): | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_local_kernel(local_kernel, device): | ||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||
|     y = torch.empty_like(x) | ||||
|  | ||||
|     local_kernel.gelu_fast(y, x) | ||||
|  | ||||
|     expected = torch.tensor( | ||||
|         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||
|         device=device, | ||||
|         dtype=torch.float16, | ||||
|     ) | ||||
|  | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_local_kernel_path_types(local_kernel_path, device): | ||||
|     package_name, path = local_kernel_path | ||||
|  | ||||
|     # Top-level repo path | ||||
|     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071 | ||||
|     kernel = get_local_kernel(path.parent.parent, package_name) | ||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||
|     y = torch.empty_like(x) | ||||
|  | ||||
|     kernel.gelu_fast(y, x) | ||||
|     expected = torch.tensor( | ||||
|         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||
|         device=device, | ||||
|         dtype=torch.float16, | ||||
|     ) | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|     # Build directory path | ||||
|     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build | ||||
|     kernel = get_local_kernel(path.parent.parent / "build", package_name) | ||||
|     y = torch.empty_like(x) | ||||
|     kernel.gelu_fast(y, x) | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|     # Explicit package path | ||||
|     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux | ||||
|     kernel = get_local_kernel(path, package_name) | ||||
|     y = torch.empty_like(x) | ||||
|     kernel.gelu_fast(y, x) | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|  | ||||
| @pytest.mark.darwin_only | ||||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) | ||||
| def test_relu_metal(metal_kernel, dtype): | ||||
|     x = torch.arange(-10, 10, dtype=dtype, device="mps") | ||||
|     y = metal_kernel.relu(x) | ||||
|     assert torch.allclose(y, torch.relu(x)) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @pytest.mark.parametrize( | ||||
|     "kernel_exists", | ||||
|     [ | ||||
|         ("kernels-community/activation", "main", True), | ||||
|         ("kernels-community/triton-layer-norm", "main", True), | ||||
|         # Repo only contains Torch 2.4 kernels (and we don't | ||||
|         # support/test against this version). | ||||
|         ("kernels-test/only-torch-2.4", "main", False), | ||||
|         ("google-bert/bert-base-uncased", "87565a309", False), | ||||
|     ], | ||||
| ) | ||||
| def test_has_kernel(kernel_exists): | ||||
|     repo_id, revision, kernel = kernel_exists | ||||
|     assert has_kernel(repo_id, revision=revision) == kernel | ||||
|  | ||||
|  | ||||
| def test_version(): | ||||
|     kernel = get_kernel("kernels-test/versions") | ||||
|     assert kernel.version() == "0.2.0" | ||||
|     kernel = get_kernel("kernels-test/versions", version="<1.0.0") | ||||
|     assert kernel.version() == "0.2.0" | ||||
|     kernel = get_kernel("kernels-test/versions", version="<0.2.0") | ||||
|     assert kernel.version() == "0.1.1" | ||||
|     kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0") | ||||
|     assert kernel.version() == "0.1.1" | ||||
|  | ||||
|     with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): | ||||
|         get_kernel("kernels-test/versions", version=">0.2.0") | ||||
|  | ||||
|     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): | ||||
|         kernel = get_kernel( | ||||
|             "kernels-test/versions", revision="v0.1.0", version="<1.0.0" | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_universal_kernel(universal_kernel): | ||||
|     torch.manual_seed(0) | ||||
|     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") | ||||
|  | ||||
| @ -16,18 +16,21 @@ def device(): | ||||
|     return "cuda" | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_gelu_small(kernel, device, benchmark): | ||||
|     x = torch.randn(32, 32, dtype=torch.float16, device=device) | ||||
|     y = torch.empty_like(x) | ||||
|     benchmark(kernel.gelu_fast, y, x) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_gelu_medium(kernel, device, benchmark): | ||||
|     x = torch.randn(128, 128, dtype=torch.float16, device=device) | ||||
|     y = torch.empty_like(x) | ||||
|     benchmark(kernel.gelu_fast, y, x) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_gelu_large(kernel, device, benchmark): | ||||
|     x = torch.randn(512, 512, dtype=torch.float16, device=device) | ||||
|     y = torch.empty_like(x) | ||||
|  | ||||
							
								
								
									
										49
									
								
								tests/test_doctest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								tests/test_doctest.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,49 @@ | ||||
| import inspect | ||||
|  | ||||
| import pytest | ||||
| from mktestdocs import check_docstring, get_codeblock_members | ||||
|  | ||||
| import kernels | ||||
|  | ||||
|  | ||||
| def all_public_functions(): | ||||
|     function_list = inspect.getmembers(kernels, inspect.isfunction) | ||||
|     return [func for _, func in function_list] | ||||
|  | ||||
|  | ||||
| def all_public_classes(): | ||||
|     class_list = inspect.getmembers(kernels, inspect.isclass) | ||||
|     return [cls for _, cls in class_list] | ||||
|  | ||||
|  | ||||
| def all_public_class_members(): | ||||
|     members = get_codeblock_members(*all_public_classes()) | ||||
|     return members | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @pytest.mark.parametrize( | ||||
|     "func", | ||||
|     all_public_functions(), | ||||
|     ids=lambda d: d.__name__, | ||||
| ) | ||||
| def test_func_docstring(func): | ||||
|     check_docstring(obj=func) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @pytest.mark.parametrize( | ||||
|     "cls", | ||||
|     all_public_classes(), | ||||
|     ids=lambda d: d.__name__, | ||||
| ) | ||||
| def test_class_docstring(cls): | ||||
|     check_docstring(obj=cls) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @pytest.mark.parametrize( | ||||
|     "member", all_public_class_members(), ids=lambda d: d.__qualname__ | ||||
| ) | ||||
| def test_member_docstring(member): | ||||
|     check_docstring(member) | ||||
							
								
								
									
										230
									
								
								tests/test_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										230
									
								
								tests/test_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,230 @@ | ||||
| import random | ||||
| from typing import Generic, List, Optional, Tuple, TypeVar | ||||
|  | ||||
| import pytest | ||||
|  | ||||
| from kernels._interval_tree import IntervalTree, _Node | ||||
|  | ||||
| T = TypeVar("T") | ||||
|  | ||||
|  | ||||
| class SimpleIntervalStore(Generic[T]): | ||||
|     """A simple O(n) implementation that stores intervals in a list.""" | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.intervals: List[Tuple[int, int, T]] = [] | ||||
|  | ||||
|     def insert(self, start: int, end: int, data: T) -> None: | ||||
|         """Insert an interval into the store.""" | ||||
|         # Replace data if the interval already exists. | ||||
|         for i, (existing_start, existing_end, existing_data) in enumerate( | ||||
|             self.intervals | ||||
|         ): | ||||
|             if existing_start == start and existing_end == end: | ||||
|                 self.intervals[i] = (start, end, data) | ||||
|                 return | ||||
|  | ||||
|         self.intervals.append((start, end, data)) | ||||
|  | ||||
|     def find_smallest_interval(self, point: int) -> Optional[T]: | ||||
|         """Find the best match using linear search.""" | ||||
|         matches = [] | ||||
|         for start, end, data in self.intervals: | ||||
|             if start <= point <= end: | ||||
|                 matches.append((start, end, data)) | ||||
|  | ||||
|         if not matches: | ||||
|             return None | ||||
|  | ||||
|         # Return the smallest interval, sort by memory location when | ||||
|         # there are multiple matches with the same interval size. This | ||||
|         # mirrors the ordering in the intervan tree. | ||||
|         best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2]))) | ||||
|         return best_match[2] | ||||
|  | ||||
|  | ||||
| def is_balanced(tree: IntervalTree[T]) -> bool: | ||||
|     """Check if the AVL tree is properly balanced.""" | ||||
|  | ||||
|     def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]: | ||||
|         if node is None: | ||||
|             return True, 0 | ||||
|  | ||||
|         # Left and right subtrees should be balanced. | ||||
|         left_balanced, left_height = check_balance(node.left) | ||||
|         if not left_balanced: | ||||
|             return False, -1 | ||||
|  | ||||
|         right_balanced, right_height = check_balance(node.right) | ||||
|         if not right_balanced: | ||||
|             return False, -1 | ||||
|  | ||||
|         # The difference in height should not exceed 1. | ||||
|         if abs(left_height - right_height) > 1: | ||||
|             return False, -1 | ||||
|  | ||||
|         # Check if the height is correct. | ||||
|         expected_height = 1 + max(left_height, right_height) | ||||
|         if node.height != expected_height: | ||||
|             return False, -1 | ||||
|  | ||||
|         return True, expected_height | ||||
|  | ||||
|     balanced, _ = check_balance(tree.root) | ||||
|     return balanced | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def populated_tree() -> IntervalTree[str]: | ||||
|     """Provides a pre-populated IntervalTree for testing.""" | ||||
|     tree = IntervalTree[str]() | ||||
|     kernels = [ | ||||
|         (80, 89, "Kernel_A_General_80_89"), | ||||
|         (86, 89, "Kernel_B_Ampere_86_89"), | ||||
|         (80, 86, "Kernel_C_Older_Ampere_80_86"), | ||||
|         (70, 75, "Kernel_D_Volta_70_75"), | ||||
|         (86, 87, "Kernel_E_Specific_86_87"), | ||||
|     ] | ||||
|     for start, end, name in kernels: | ||||
|         tree.insert(start, end, name) | ||||
|     return tree | ||||
|  | ||||
|  | ||||
| def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree): | ||||
|     # Check that the smallest inteval is selected when there are | ||||
|     # multiple matching intervals. | ||||
|     assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87" | ||||
|  | ||||
|  | ||||
| def test_find_single_match(populated_tree): | ||||
|     assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75" | ||||
|     assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75" | ||||
|  | ||||
|  | ||||
| def test_no_match_outside_all_ranges(populated_tree): | ||||
|     # Check that no interval is found when the value is out of range | ||||
|     # (too small/too large). | ||||
|     assert populated_tree.find_smallest_interval(65) is None | ||||
|     assert populated_tree.find_smallest_interval(95) is None | ||||
|  | ||||
|  | ||||
| def test_no_match_in_gap_between_ranges(populated_tree): | ||||
|     # Check that no interval is found when the value is between two | ||||
|     # intervals. | ||||
|     assert populated_tree.find_smallest_interval(78) is None | ||||
|  | ||||
|  | ||||
| def test_boundary_conditions_start_and_end(populated_tree): | ||||
|     # Test exact upper/lower bounds of intervals. | ||||
|     assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86" | ||||
|     assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89" | ||||
|  | ||||
|  | ||||
| def test_empty_tree(): | ||||
|     # Searching in an empty tree should return None. | ||||
|     empty_tree = IntervalTree[str]() | ||||
|     assert empty_tree.find_smallest_interval(100) is None | ||||
|  | ||||
|  | ||||
| def test_multiple_equally_specific_matches(): | ||||
|     # Check that we pick the match in a stable way when there is are | ||||
|     # multiple matching intervals with the same size. | ||||
|     tree = IntervalTree[str]() | ||||
|     str1 = "First_Narrow_Kernel" | ||||
|     str2 = "Second_Narrow_Kernel" | ||||
|     tree.insert(10, 20, "Wide_Kernel") | ||||
|     tree.insert(12, 17, str1) | ||||
|     tree.insert(14, 19, str2) | ||||
|  | ||||
|     if id(str1) < id(str2): | ||||
|         assert tree.find_smallest_interval(15) == str1 | ||||
|     else: | ||||
|         assert tree.find_smallest_interval(15) == str2 | ||||
|  | ||||
|  | ||||
| def test_property_based_interval_tree(): | ||||
|     # Quick-check property-based testing: | ||||
|     # | ||||
|     # - Verify that the tree is balanced after each insertion. | ||||
|     # - Verify the query against a simple list-based implementation. | ||||
|  | ||||
|     random.seed(42)  # For reproducible tests | ||||
|  | ||||
|     test_points = list(range(0, 101)) | ||||
|  | ||||
|     for _ in range(5): | ||||
|         tree = IntervalTree[str]() | ||||
|         simple = SimpleIntervalStore[str]() | ||||
|  | ||||
|         intervals = [] | ||||
|         for i in range(100): | ||||
|             start = random.randint(0, 90) | ||||
|             end = random.randint(start, 100) | ||||
|             data = f"interval_{i}_s{start}_e{end}" | ||||
|             intervals.append((start, end, data)) | ||||
|  | ||||
|         for i, (start, end, data) in enumerate(intervals): | ||||
|             tree.insert(start, end, data) | ||||
|             simple.insert(start, end, data) | ||||
|  | ||||
|             # Check that tree is still balanced | ||||
|             assert is_balanced( | ||||
|                 tree | ||||
|             ), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})" | ||||
|  | ||||
|             for point in test_points: | ||||
|                 tree_result = tree.find_smallest_interval(point) | ||||
|                 simple_result = simple.find_smallest_interval(point) | ||||
|  | ||||
|                 assert tree_result == simple_result, ( | ||||
|                     f"Mismatch for point {point} after inserting {i+1} intervals. " | ||||
|                     f"Tree: {tree_result}, Simple: {simple_result}. " | ||||
|                     f"Last inserted: ({start}, {end})" | ||||
|                 ) | ||||
|  | ||||
|  | ||||
| def test_property_based_edge_cases(): | ||||
|     random.seed(123) | ||||
|  | ||||
|     tree = IntervalTree[str]() | ||||
|     simple = SimpleIntervalStore[str]() | ||||
|  | ||||
|     # Single-point intervals. | ||||
|     for i in range(10): | ||||
|         point = random.randint(0, 100) | ||||
|         data = f"single_point_{i}_{point}" | ||||
|         tree.insert(point, point, data) | ||||
|         simple.insert(point, point, data) | ||||
|  | ||||
|         assert is_balanced( | ||||
|             tree | ||||
|         ), f"Tree unbalanced after inserting single point {point}" | ||||
|  | ||||
|         # Test the exact point and neighbors | ||||
|         for test_point in [point - 1, point, point + 1]: | ||||
|             if 0 <= test_point <= 100: | ||||
|                 tree_result = tree.find_smallest_interval(test_point) | ||||
|                 simple_result = simple.find_smallest_interval(test_point) | ||||
|                 assert tree_result == simple_result | ||||
|  | ||||
|  | ||||
| def test_unique_intervals_override(): | ||||
|     """Test that inserting an interval with the same start/end overrides the previous value.""" | ||||
|     tree = IntervalTree[str]() | ||||
|  | ||||
|     tree.insert(10, 20, "original_value") | ||||
|     assert tree.find_smallest_interval(15) == "original_value" | ||||
|  | ||||
|     tree.insert(10, 20, "new_value") | ||||
|     assert tree.find_smallest_interval(15) == "new_value" | ||||
|  | ||||
|     tree.insert(10, 25, "different_interval") | ||||
|     results = tree.search(15) | ||||
|     assert "new_value" in results | ||||
|     assert "different_interval" in results | ||||
|     assert len(results) == 2 | ||||
|  | ||||
|     tree.insert(10, 20, "final_value") | ||||
|     assert tree.find_smallest_interval(15) == "final_value" | ||||
|  | ||||
|     assert is_balanced(tree) | ||||
| @ -1,8 +1,18 @@ | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
|  | ||||
| import pytest | ||||
| import torch.nn as nn | ||||
|  | ||||
| from kernels import load_kernel | ||||
| from kernels.cli import download_kernels | ||||
| from kernels.layer import ( | ||||
|     LockedLayerRepository, | ||||
|     Mode, | ||||
|     kernelize, | ||||
|     use_kernel_forward_from_hub, | ||||
|     use_kernel_mapping, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Mock download arguments class. | ||||
| @ -17,8 +27,35 @@ def test_download_all_hash_validation(): | ||||
|     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_load_locked(): | ||||
|     project_dir = Path(__file__).parent / "kernel_locking" | ||||
|     # Also validates that hashing works correctly. | ||||
|     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) | ||||
|     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_layer_locked(): | ||||
|     project_dir = Path(__file__).parent / "layer_locking" | ||||
|  | ||||
|     @use_kernel_forward_from_hub("Version") | ||||
|     class Version(nn.Module): | ||||
|         def forward(self) -> str: | ||||
|             return "0.0.0" | ||||
|  | ||||
|     version = Version() | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 "cuda": LockedLayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                     lockfile=project_dir / "kernels.lock", | ||||
|                 ) | ||||
|             }, | ||||
|         } | ||||
|     ): | ||||
|         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|         assert version() == "0.1.1" | ||||
|  | ||||
							
								
								
									
										117
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,117 @@ | ||||
| import logging | ||||
| import os | ||||
| import re | ||||
| import tempfile | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
|  | ||||
| import pytest | ||||
| from huggingface_hub import delete_repo, model_info | ||||
|  | ||||
| from kernels.cli import upload_kernels | ||||
|  | ||||
| REPO_ID = "valid_org/kernels-upload-test" | ||||
|  | ||||
|  | ||||
| PY_CONTENT = """\ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| def main(): | ||||
|     print("Hello from torch-universal!") | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
| """ | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class UploadArgs: | ||||
|     kernel_dir: None | ||||
|     repo_id: None | ||||
|     private: False | ||||
|     branch: None | ||||
|  | ||||
|  | ||||
| def next_filename(path: Path) -> Path: | ||||
|     """ | ||||
|     Given a path like foo_2050.py, return foo_2051.py. | ||||
|     """ | ||||
|     m = re.match(r"^(.*?)(\d+)(\.py)$", path.name) | ||||
|     if not m: | ||||
|         raise ValueError( | ||||
|             f"Filename {path.name!r} does not match pattern <prefix>_<number>.py" | ||||
|         ) | ||||
|  | ||||
|     prefix, number, suffix = m.groups() | ||||
|     new_number = str(int(number) + 1).zfill(len(number)) | ||||
|     return path.with_name(f"{prefix}{new_number}{suffix}") | ||||
|  | ||||
|  | ||||
| def get_filename_to_change(repo_filenames): | ||||
|     for f in repo_filenames: | ||||
|         if "foo" in f and f.endswith(".py"): | ||||
|             filename_to_change = os.path.basename(f) | ||||
|             break | ||||
|     assert filename_to_change | ||||
|     return filename_to_change | ||||
|  | ||||
|  | ||||
| def get_filenames_from_a_repo(repo_id: str) -> List[str]: | ||||
|     try: | ||||
|         repo_info = model_info(repo_id=repo_id, files_metadata=True) | ||||
|         repo_siblings = repo_info.siblings | ||||
|         if repo_siblings is not None: | ||||
|             return [f.rfilename for f in repo_siblings] | ||||
|         else: | ||||
|             raise ValueError("No repo siblings found.") | ||||
|     except Exception as e: | ||||
|         logging.error(f"Error connecting to the Hub: {e}.") | ||||
|  | ||||
|  | ||||
| @pytest.mark.token | ||||
| @pytest.mark.is_staging_test | ||||
| @pytest.mark.parametrize("branch", (None, "foo")) | ||||
| def test_kernel_upload_works_as_expected(branch): | ||||
|     with tempfile.TemporaryDirectory() as tmpdir: | ||||
|         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||
|         build_dir = Path(path) | ||||
|         build_dir.mkdir(parents=True, exist_ok=True) | ||||
|         script_path = build_dir / "foo.py" | ||||
|         script_path.write_text(PY_CONTENT) | ||||
|         upload_kernels(UploadArgs(tmpdir, REPO_ID, False, branch)) | ||||
|  | ||||
|     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||
|     assert any(str(script_path.name) for f in repo_filenames) | ||||
|     delete_repo(repo_id=REPO_ID) | ||||
|  | ||||
|  | ||||
| @pytest.mark.token | ||||
| @pytest.mark.is_staging_test | ||||
| def test_kernel_upload_deletes_as_expected(): | ||||
|     with tempfile.TemporaryDirectory() as tmpdir: | ||||
|         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||
|         build_dir = Path(path) | ||||
|         build_dir.mkdir(parents=True, exist_ok=True) | ||||
|         script_path = build_dir / "foo_2025.py" | ||||
|         script_path.write_text(PY_CONTENT) | ||||
|         upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) | ||||
|  | ||||
|     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||
|     filename_to_change = get_filename_to_change(repo_filenames) | ||||
|  | ||||
|     with tempfile.TemporaryDirectory() as tmpdir: | ||||
|         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||
|         build_dir = Path(path) | ||||
|         build_dir.mkdir(parents=True, exist_ok=True) | ||||
|         changed_filename = next_filename(Path(filename_to_change)) | ||||
|         script_path = build_dir / changed_filename | ||||
|         script_path.write_text(PY_CONTENT) | ||||
|         upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) | ||||
|  | ||||
|     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||
|     assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}" | ||||
|     assert not any( | ||||
|         str(filename_to_change) in k for k in repo_filenames | ||||
|     ), f"{repo_filenames=}" | ||||
|     delete_repo(repo_id=REPO_ID) | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user
	