Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-31 19:54:28 +08:00)

# Compare commits: `fixup-arg-...upload-rev` (96 commits)
| SHA1 |
|---|
| b2e8703329 |
| 060c326d89 |
| 34a1932751 |
| e39eac09c1 |
| b0c431fee4 |
| 9a188eadbe |
| 457c7c1b8d |
| fb8cd99a2c |
| dfee307d54 |
| 93e5765611 |
| bf488208be |
| 2a14472e4c |
| 055a953552 |
| 692d5ad458 |
| 2139df57f4 |
| 8f9a77bb6a |
| 6c00194680 |
| d6b51eefb7 |
| d383fdd4b4 |
| 07e5e8481a |
| 88f55d4728 |
| e801ebf332 |
| 0ae07f05fc |
| 7611021100 |
| 767e7ccf13 |
| 1caa4c1393 |
| da701bf58a |
| 703664ed31 |
| a8a6564fa7 |
| c89e0fa9b9 |
| 176a601178 |
| cfa0c76ddc |
| bcc29915f9 |
| 6fbff7a9cb |
| f7490bd0a9 |
| 8069e3bf0c |
| c540d1e1d6 |
| 967ac581b8 |
| 81088d44e8 |
| 4a04c005e3 |
| 6d3c6daf20 |
| 071900fd69 |
| 2d2c6b14e0 |
| 03edc573b1 |
| c841a6c90d |
| c7a343f195 |
| 8d838f947d |
| b87e6fadbe |
| fc935d9874 |
| 3622e1f8dd |
| a7f3b2e8ed |
| a6ab5d83ba |
| 4f9f1abfb9 |
| f94b7780a6 |
| bd28883775 |
| 498429e322 |
| 09c991af4b |
| bcf8df5875 |
| 239afff6f5 |
| c5ec6b900a |
| 3a635eaeea |
| 32ec496c5a |
| 848c6db87b |
| fabb8c52d1 |
| d66260dd83 |
| daac8078fc |
| fcb9a80ce6 |
| c25bb32e6e |
| 2036892762 |
| 0f0de049cf |
| 59597df03e |
| 5e938ede40 |
| cf530c283a |
| 437f910336 |
| 6f1a6067c8 |
| 1d14abcef0 |
| 6fd2112e22 |
| 70f56ff856 |
| 7178b0b86c |
| 0bbf90a564 |
| 27d6ffcb80 |
| f7bd21438b |
| 6174febb4b |
| ff55bc201b |
| 3808108d62 |
| c4a16ef462 |
| 9762794dd2 |
| b7d6867c52 |
| fbcd0f2ebd |
| 5af46eca94 |
| 747dd66876 |
| 920590a592 |
| 5208ac4be5 |
| 22eaba2826 |
| 9521ba79a0 |
| 9861a5bdef |
## .github/workflows/build_documentation.yaml (new file, 17 lines)

```yaml
name: Build documentation

on:
  push:
    branches:
      - main
      - doc-builder*
      - v*-release

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    with:
      commit_sha: ${{ github.sha }}
      package: kernels
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
```
## .github/workflows/build_pr_documentation.yaml (new file, 15 lines)

```yaml
name: Build PR Documentation

on: pull_request

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
      package: kernels
```
## .github/workflows/lint.yml (+21 lines)

```diff
@@ -8,3 +8,24 @@ jobs:
       - uses: actions/checkout@v4
       - name: Run ruff
         uses: astral-sh/ruff-action@v3
+
+  black:
+    name: Run black check
+    runs-on: ubuntu-latest
+    env:
+      UV_PYTHON_PREFERENCE: only-managed
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Install uv and set the python version
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: 3.12
+
+      - name: Install black
+        run: uv pip install black
+
+      - name: Check formatting
+        run: |
+          uv run black --check src
+          uv run black --check tests
```
## .github/workflows/publish.yml (new file, 120 lines)

```yaml
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI

on: push

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/kernels
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

  github-release:
    name: >-
      Sign the Python 🐍 distribution 📦 with Sigstore
      and upload them to GitHub Release
    needs:
      - publish-to-pypi
    runs-on: ubuntu-latest

    permissions:
      contents: write # IMPORTANT: mandatory for making GitHub Releases
      id-token: write # IMPORTANT: mandatory for sigstore

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Sign the dists with Sigstore
        uses: sigstore/gh-action-sigstore-python@v3.0.0
        with:
          inputs: >-
            ./dist/*.tar.gz
            ./dist/*.whl
      - name: Create GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: >-
          gh release create
          "$GITHUB_REF_NAME"
          --repo "$GITHUB_REPOSITORY"
          --notes ""
      - name: Upload artifact signatures to GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        # Upload to GitHub Release using the `gh` CLI.
        # `dist/` contains the built packages, and the
        # sigstore-produced signatures and certificates.
        run: >-
          gh release upload
          "$GITHUB_REF_NAME" dist/**
          --repo "$GITHUB_REPOSITORY"

  publish-to-testpypi:
    name: Publish Python 🐍 distribution 📦 to TestPyPI
    needs:
      - build
    runs-on: ubuntu-latest

    environment:
      name: testpypi
      url: https://test.pypi.org/p/kernels

    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing

    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/
          skip-existing: true # Only upload when the version is unique.
```
## .github/workflows/test.yml (34 changed lines)

```diff
@@ -24,7 +24,7 @@ jobs:
       max-parallel: 4
       matrix:
         python-version: ["3.10", "3.12"]
-        torch-version: ["2.5.1", "2.6.0"]
+        torch-version: ["2.6.0", "2.7.0"]

     env:
       UV_PYTHON_PREFERENCE: only-managed
@@ -51,4 +51,34 @@ jobs:
         run: uv run mypy src/kernels

       - name: Run tests
-        run: uv run pytest tests
+        run: |
+          uv run pytest tests
+
+      - name: Run staging tests
+        env:
+          HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
+        run: |
+          HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
+        if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0'
+
+      - name: Check kernel conversion
+        run: |
+          uv pip install wheel
+          uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
+          uv pip install triton_layer_norm-0.0.1*.whl
+          uv run python -c "import triton_layer_norm"
+
+      - name: Check README generation
+        # For now, just checks that generation doesn't fail.
+        run: |
+          uv run kernels generate-readme kernels-community/triton-layer-norm
+
+      - name: Check kernel check
+        run: |
+          uv pip install kernel-abi-check
+          kernels check kernels-community/activation
+
+      - name: Import check without torch
+        run: |
+          uv pip uninstall torch
+          python -c "import kernels"
```
## .github/workflows/upload_pr_documentation.yaml (new file, 16 lines)

```yaml
name: Upload PR Documentation

on:
  workflow_run:
    workflows: ["Build PR Documentation"]
    types:
      - completed

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: kernels
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
```
## LICENSE (new file, 201 lines)

```text
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
```
## Makefile (new file, 8 lines)

```makefile
.PHONY: style

export check_dirs := src examples tests

style:
	black ${check_dirs}
	isort ${check_dirs}
	ruff check ${check_dirs} --fix
```
## README.md (23 changed lines)

```diff
@@ -1,5 +1,16 @@
 # kernels

+<div align="center">
+<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
+<p align="center">
+    <a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
+
+</p>
+</div>
+<hr/>
+
 The Kernel Hub allows Python libraries and applications to load compute
 kernels directly from the [Hub](https://hf.co/). To support this kind
 of dynamic loading, Hub kernels differ from traditional Python kernel
@@ -45,8 +56,12 @@ the Hub.

 ## 📚 Documentation

-- [Using layers](docs/layers.md)
-- [Locking kernel versions](docs/locking.md)
-- [Using kernels in a Docker container](docs/docker.md)
-- [Kernel requirements](docs/kernel-requirements.md)
+- [Introduction](docs/source/index.md)
+- [Installation](docs/source/installation.md)
+- [Basic usage](docs/source/basic-usage.md)
+- [Using layers](docs/source/layers.md)
+- [Locking kernel/layer versions](docs/source/locking.md)
+- [Environment variables](docs/source/env.md)
+- [Kernel requirements](docs/source/kernel-requirements.md)
+- [Frequently Asked Questions](docs/source/faq.md)
 - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
```
## docs/docker.md (deleted, -8 lines)

```diff
@@ -1,8 +0,0 @@
-# Using kernels in a Docker container
-
-build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
-
-```bash
-docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
-docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
-```
```
## docs/layers.md (deleted, -79 lines)

```diff
@@ -1,79 +0,0 @@
-# Layers
-
-A kernel can provide layers in addition to kernel functions. A layer from
-the Hub can replace the `forward` method of an existing layer for a certain
-device type. This makes it possible to provide more performant kernels for
-existing layers.
-
-See [Kernel requirements](kernel-requirements.md) for more information the
-requirements of Hub layers.
-
-## Making a layer extensible with kernels from the hub
-
-### Using a decorator
-
-A layer can be made extensible with the `use_kernel_forward_from_hub`
-decorator. For example:
-
-```python
-@use_kernel_forward_from_hub("SiluAndMul")
-class SiluAndMul(nn.Module):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        d = input.shape[-1] // 2
-        return F.silu(input[..., :d]) * input[..., d:]
-```
-
-The decorator changes the layer, so that other implementations of the `forward`
-method can be registered using the name `SiluAndMul`.
-
-### External layers
-
-An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
-decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
-
-```python
-from somelibrary import SiluAndMul
-
-replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
-register_kernel_mapping(kernel_layer_mapping)
-```
-
-The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
-hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
-section for more information.
-
-**Warning:** we strongly recommend using layers with a decorator, since
-it signifies that the maintainer intends to keep the `forward` signature
-compatible with layers from the hub.
-
-## Registering a hub kernel for a layer
-
-Once a layer is made extensible, users can register hub kernels for it
-by name using the `register_kernel_mapping` function. For example:
-
-```python
-kernel_layer_mapping = {
-    "SiluAndMul": {
-        "cuda": LayerRepository(
-            repo_id="kernels-community/activation",
-            layer_name="SiluAndMul",
-            revision="layers",
-        )
-    }
-}
-
-register_kernel_mapping(kernel_layer_mapping)
-```
-
-This will register the kernel mapping in the current context, which is
-normally global. It is recommended to scope the mapping to where it is
-used with the `use_kernel_mapping` context manager:
-
-```python
-with use_kernel_mapping(kernel_layer_mapping):
-    # Use the layer for which the mapping is applied.
-    ...
-```
-
-This ensures that the mapping is not active anymore outside the
-`with`-scope.
```
## docs/source/_toctree.yml (new file, 30 lines)

```yaml
- sections:
    - local: index
      title: Introduction
    - local: installation
      title: Installation
  title: Getting started
- sections:
    - local: basic-usage
      title: Basic Usage
    - local: layers
      title: Using Layers
    - local: locking
      title: Locking Kernel Versions
    - local: env
      title: Environment Variables
    - local: faq
      title: FAQ
  title: Usage Guide
- sections:
    - local: api/kernels
      title: Kernels
    - local: api/layers
      title: Layers
    - local: cli
      title: Kernels CLI
  title: API Reference
- sections:
    - local: kernel-requirements
      title: Kernel Requirements
  title: Developer Guide
```
## docs/source/api/kernels.md (new file, 25 lines)

# Kernels API Reference

## Main Functions

### get_kernel

[[autodoc]] kernels.get_kernel

### get_local_kernel

[[autodoc]] kernels.get_local_kernel

### has_kernel

[[autodoc]] kernels.has_kernel

## Loading locked kernels

### load_kernel

[[autodoc]] kernels.load_kernel

### get_locked_kernel

[[autodoc]] kernels.get_locked_kernel
## docs/source/api/layers.md (new file, 49 lines)

# Layers API Reference

## Making layers kernel-aware

### use_kernel_forward_from_hub

[[autodoc]] kernels.use_kernel_forward_from_hub

### replace_kernel_forward_from_hub

[[autodoc]] kernels.replace_kernel_forward_from_hub

## Registering kernel mappings

### use_kernel_mapping

[[autodoc]] kernels.use_kernel_mapping

### register_kernel_mapping

[[autodoc]] kernels.register_kernel_mapping

## Kernelizing a model

### kernelize

[[autodoc]] kernels.kernelize

## Classes

### Device

[[autodoc]] kernels.Device

### Mode

[[autodoc]] kernels.Mode

### LayerRepository

[[autodoc]] kernels.LayerRepository

### LocalLayerRepository

[[autodoc]] kernels.LocalLayerRepository

### LockedLayerRepository

[[autodoc]] kernels.LockedLayerRepository
## docs/source/basic-usage.md (new file, 50 lines)

# Basic Usage

## Loading Kernels

Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:

```python
import torch
from kernels import get_kernel

# Download optimized kernels from the Hugging Face hub
activation = get_kernel("kernels-community/activation")

# Create a random tensor
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")

# Run the kernel
y = torch.empty_like(x)
activation.gelu_fast(y, x)

print(y)
```

### Using version bounds

Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
You can specify which version to download using Python version specifiers:

```python
import torch
from kernels import get_kernel

activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
```

This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
is strongly recommended to specify a version bound, since a kernel author
might push incompatible changes to the `main` branch.

## Checking Kernel Availability

You can check if a specific kernel is available for your environment:

```python
from kernels import has_kernel

# Check if kernel is available for current environment
is_available = has_kernel("kernels-community/activation")
print(f"Kernel available: {is_available}")
```
## docs/source/cli.md (new file, 58 lines)

# Kernels CLI Reference

## Main Functions

### kernels check

You can use `kernels check` to test the compliance of a kernel on the Hub.
This currently checks that the kernel:

- Supports the currently-required Python ABI version.
- Works on supported operating system versions.

For example:

```bash
$ kernels check kernels-community/flash-attn3
Checking variant: torch28-cxx11-cu128-aarch64-linux
  🐍 Python ABI 3.9 compatible
  🐧 manylinux_2_28 compatible
[...]
```

### kernels to-wheel

We strongly recommend downloading kernels from the Hub using the `kernels`
package, since this comes with large [benefits](index.md) over using Python
wheels. That said, some projects may require deployment of kernels as
wheels. The `kernels` utility provides a simple solution to this. You can
convert any Hub kernel into a set of wheels with the `to-wheel` command:

```bash
$ kernels to-wheel drbh/img2grey 1.1.2
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
```

### kernels upload

Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
your kernel builds to the Hub. Run `kernels upload -h` to see the supported arguments.

**Notes**:

- This will take care of creating a repository on the Hub with the provided `repo_id`.
- If a repo with the `repo_id` already exists and contains a `build` directory with the
  build variant being uploaded, the existing files under it will be deleted first.
- Make sure you are authenticated (run `hf auth login` if needed) to upload to the Hub.
## docs/source/env.md (new file, 10 lines)

# Environment variables

## `KERNELS_CACHE`

The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.

## `DISABLE_KERNEL_MAPPING`

Disables kernel mappings for [`layers`](layers.md).
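For illustration, a minimal sketch of pointing the cache at a custom directory; the path is hypothetical, and this assumes `KERNELS_CACHE` is read when `kernels` first resolves its cache location, so it should be set before the first download:

```python
import os

# Hypothetical cache directory; set before the first kernel download.
os.environ["KERNELS_CACHE"] = "/opt/kernels-cache"

from kernels import get_kernel

# Downloaded to (or reused from) /opt/kernels-cache instead of the
# default huggingface_hub cache.
activation = get_kernel("kernels-community/activation")
```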
## docs/source/faq.md (new file, 41 lines)
# FAQ

## Kernel layers

### Why is kernelization needed as a separate step?

In earlier versions of `kernels`, a layer's `forward` method was replaced
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
The new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile`, since it relied
on data-dependent branching.

To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
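As a minimal sketch of what ahead-of-time dispatch looks like (the toy `Model` is hypothetical, and the exact `kernelize` keywords should be checked against the API reference for `kernelize` and `Mode`):

```python
import torch
import torch.nn as nn
from kernels import Mode, kernelize

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for layers made kernel-extensible with
        # use_kernel_forward_from_hub.
        self.linear = nn.Linear(128, 128)

    def forward(self, x):
        return self.linear(x)

model = Model()

# Dispatch is resolved once, here, rather than on every forward() call,
# so the kernelized model contains no data-dependent branching ...
model = kernelize(model, mode=Mode.INFERENCE)

# ... and can therefore be compiled without dispatch-related graph breaks.
compiled = torch.compile(model)
```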
### Why does kernelization only replace `forward` methods?

There are some other possible approaches. The first is to completely
replace existing layers with kernel layers. However, since this would
permit free-form layer classes, it would be much harder to validate
that layers are fully compatible with the layers they are
replacing. For instance, they could have completely different member
variables. Besides that, we would also need to hold on to the original
layers, in case we need to revert to the base layers when the model
is `kernelize`d again with different options.

A second approach would be to make an auxiliary layer that wraps the
original layer and the kernel layer and dispatches to the kernel layer.
This wouldn't have the issues of the first approach, because kernel layers
could be similarly strict as they are now, and we would still have access
to the original layers when `kernelize`-ing the model again. However,
this would change the graph structure of the model and would break use
cases where programs access the model internals (e.g.
`model.layers[0].attention.query_weight`) or rely on the graph structure
in other ways.

The approach of `forward`-replacement is the least invasive, because
it preserves the original model graph. It is also reversible: even
though the `forward` of a layer _instance_ might be replaced,
the corresponding class still has the original `forward`.
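A small illustration of the reversibility point, in plain PyTorch rather than the `kernels` API itself (`kernel_forward` is a hypothetical stand-in for a hub kernel's `forward`):

```python
import torch.nn as nn
import torch.nn.functional as F

class Silu(nn.Module):
    def forward(self, x):
        return F.silu(x)

def kernel_forward(self, x):
    return F.silu(x)  # hypothetical optimized replacement

layer = Silu()
layer.forward = kernel_forward.__get__(layer)  # patch this instance only

assert layer.forward.__func__ is kernel_forward  # the instance is patched
assert Silu.forward is not kernel_forward        # the class keeps the original

del layer.__dict__["forward"]                    # revert the patch
assert layer.forward.__func__ is Silu.forward
```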
## docs/source/index.md (new file, 20 lines)

# Kernels

<div align="center">
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
</div>

The Kernel Hub allows Python libraries and applications to load compute
kernels directly from the [Hub](https://hf.co/). To support this kind
of dynamic loading, Hub kernels differ from traditional Python kernel
packages in that they are made to be:

- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
- **Unique**: multiple versions of the same kernel can be loaded in the
  same Python process.
- **Compatible**: kernels must support all recent versions of Python and
  the different PyTorch build configurations (various CUDA versions
  and C++ ABIs). Furthermore, older C library versions must be supported.

You can [search for kernels](https://huggingface.co/models?other=kernel) on
the Hub.
## docs/source/installation.md (new file, 16 lines)

# Installation

Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):

```bash
pip install kernels
```

# Using kernels in a Docker container

Build and run the reference `examples/basic.py` in a Docker container with the following commands:

```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```
| @ -1,8 +1,11 @@ | |||||||
| # Kernel requirements | # Kernel requirements | ||||||
| 
 | 
 | ||||||
| Kernels on the Hub must fulfill the requirements outlined on this page. | Kernels on the Hub must fulfill the requirements outlined on this page. By | ||||||
|  | ensuring kernels are compliant, they can be used on a wide range of Linux | ||||||
|  | systems and Torch builds. | ||||||
|  | 
 | ||||||
| You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||||
| to build conforming kernels. | to build compliant kernels. | ||||||
| 
 | 
 | ||||||
| ## Directory layout | ## Directory layout | ||||||
| 
 | 
 | ||||||
| @ -10,52 +13,83 @@ A kernel repository on the Hub must contain a `build` directory. This | |||||||
| directory contains build variants of a kernel in the form of directories | directory contains build variants of a kernel in the form of directories | ||||||
| following the template | following the template | ||||||
| `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`. | `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`. | ||||||
| For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently | For example `build/torch26-cxx98-cu118-x86_64-linux`. | ||||||
| recommended build variants are: |  | ||||||
| 
 | 
 | ||||||
| - `torch25-cxx11-cu118-x86_64-linux` | Each variant directory must contain a single directory with the same name | ||||||
| - `torch25-cxx11-cu121-x86_64-linux` |  | ||||||
| - `torch25-cxx11-cu124-x86_64-linux` |  | ||||||
| - `torch25-cxx98-cu118-x86_64-linux` |  | ||||||
| - `torch25-cxx98-cu121-x86_64-linux` |  | ||||||
| - `torch25-cxx98-cu124-x86_64-linux` |  | ||||||
| - `torch26-cxx11-cu118-x86_64-linux` |  | ||||||
| - `torch26-cxx11-cu124-x86_64-linux` |  | ||||||
| - `torch26-cxx11-cu126-x86_64-linux` |  | ||||||
| - `torch26-cxx98-cu118-x86_64-linux` |  | ||||||
| - `torch26-cxx98-cu124-x86_64-linux` |  | ||||||
| - `torch26-cxx98-cu126-x86_64-linux` |  | ||||||
| 
 |  | ||||||
| This list will be updated as new PyTorch versions are released. Kernels |  | ||||||
| that are in pure Python (e.g. Triton kernels) only need to provide a |  | ||||||
| single build variant: |  | ||||||
| 
 |  | ||||||
| - `torch-universal` |  | ||||||
| 
 |  | ||||||
| Each variant directory should contain a single directory with the same name |  | ||||||
| as the repository (replacing `-` by `_`). For instance, kernels in the | as the repository (replacing `-` by `_`). For instance, kernels in the | ||||||
| `kernels-community/activation` repository have a directories like | `kernels-community/activation` repository have a directories like | ||||||
| `build/<variant>/activation`. This directory | `build/<variant>/activation`. This directory | ||||||
| must be a Python package with an `__init__.py` file. | must be a Python package with an `__init__.py` file. | ||||||
| 
 | 
 | ||||||
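|  | For example, a single build variant of a hypothetical `activation` kernel | ||||||
|  | could be laid out as follows (the shared object name is illustrative): | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | build/ | ||||||
|  | └── torch26-cxx11-cu126-x86_64-linux/ | ||||||
|  |     └── activation/ | ||||||
|  |         ├── __init__.py | ||||||
|  |         └── _activation.abi3.so | ||||||
|  | ``` | ||||||
|  |  | ||||||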
|  | ## Build variants | ||||||
|  | 
 | ||||||
|  | A kernel can be compliant for a specific compute framework (e.g. CUDA) or | ||||||
|  | architecture (e.g. x86_64). To be compliant for a given combination of compute | ||||||
|  | framework and architecture, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md) | ||||||
|  | must be available for that combination. | ||||||
|  | 
 | ||||||
|  | ## Versioning | ||||||
|  | 
 | ||||||
|  | Kernels are versioned on the Hub using Git tags. Version tags must be of | ||||||
|  | the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md) | ||||||
|  | to resolve version constraints. | ||||||
|  | 
 | ||||||
|  | We recommend using [semver](https://semver.org/) to version kernels. | ||||||
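|  |  | ||||||
|  | For example, a `v0.1.0` release can be tagged and pushed with: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | git tag v0.1.0 | ||||||
|  | git push origin v0.1.0 | ||||||
|  | ``` | ||||||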
|  | 
 | ||||||
| ## Native Python module | ## Native Python module | ||||||
| 
 | 
 | ||||||
| Kernels will typically contain a native Python module with precompiled | Kernels will typically contain a native Python module with precompiled | ||||||
| compute kernels and bindings. This module must fulfill the following | compute kernels and bindings. This module must fulfill the requirements | ||||||
| requirements: | outlined in this section. For all operating systems, a kernel must not | ||||||
|  | have dynamic library dependencies outside: | ||||||
|  | 
 | ||||||
|  | - Torch; | ||||||
|  | - CUDA/ROCm libraries installed as dependencies of Torch. | ||||||
|  | 
 | ||||||
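|  | You can inspect the dynamic dependencies of a built extension with `ldd`; | ||||||
|  | the path below is illustrative: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | ldd build/torch26-cxx11-cu126-x86_64-linux/activation/_activation.abi3.so | ||||||
|  | ``` | ||||||
|  |  | ||||||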
|  | ## Compatibility with torch.compile | ||||||
|  | 
 | ||||||
|  | The Kernel Hub also encourages writing kernels in a `torch.compile`-compliant | ||||||
|  | way. This helps to ensure that kernels are compatible with `torch.compile` | ||||||
|  | without introducing graph breaks or triggering recompilation, both of which | ||||||
|  | can limit the benefits of compilation. | ||||||
|  | 
 | ||||||
|  | [Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and  | ||||||
|  | recompilation triggers during `torch.compile`. | ||||||
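|  |  | ||||||
|  | As a quick local check, you can count graph breaks with | ||||||
|  | `torch._dynamo.explain` (a minimal sketch; `layer` and `x` are placeholders | ||||||
|  | for a kernel layer and an example input): | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | # Requires PyTorch >= 2.1; `layer` and `x` are placeholders. | ||||||
|  | explanation = torch._dynamo.explain(layer)(x) | ||||||
|  | assert explanation.graph_break_count == 0, explanation.break_reasons | ||||||
|  | ``` | ||||||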
|  | 
 | ||||||
|  | ### Linux | ||||||
| 
 | 
 | ||||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||||
|   for compatibility with Python 3.9 and later. |   for compatibility with Python 3.9 and later. | ||||||
| - Compatible with glibc 2.27 or later. This means that no symbols | - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based). | ||||||
|   from later versions must be used. To archive this, the module should |   This means that the extension **must not** use symbol versions higher than: | ||||||
|   be built against this glibc version. **Warning:** libgcc must also be |   - GLIBC 2.28 | ||||||
|   built against glibc 2.27 to avoid leaking symbols. |   - GLIBCXX 3.4.24 | ||||||
| - No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols |   - CXXABI 1.3.11 | ||||||
|   must be static. |   - GCC 7.0.0 | ||||||
| - No dynamic library dependencies outside Torch or CUDA libraries |  | ||||||
|   installed as dependencies of Torch. |  | ||||||
| 
 | 
 | ||||||
| (These requirements will be updated as new PyTorch versions are released.) | These requirements can be checked with the ABI checker (see below). | ||||||
|  | 
 | ||||||
|  | ### macOS | ||||||
|  | 
 | ||||||
|  | - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||||
|  |   for compatibility with Python 3.9 and later. | ||||||
|  | - macOS deployment target 15.0. | ||||||
|  | - Metal 3.0 (`-std=metal3.0`). | ||||||
|  | 
 | ||||||
|  | The ABI3 requirement can be checked with the ABI checker (see below). | ||||||
|  | 
 | ||||||
|  | ### ABI checker | ||||||
|  | 
 | ||||||
|  | The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | ||||||
|  | [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | ||||||
|  | 
 | ||||||
|  | ```bash | ||||||
|  | $ cargo install kernel-abi-check | ||||||
|  | $ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so | ||||||
|  | 🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9 | ||||||
|  | ✅ No compatibility issues found | ||||||
|  | ``` | ||||||
| 
 | 
 | ||||||
| ## Torch extension | ## Torch extension | ||||||
| 
 | 
 | ||||||
| @ -98,10 +132,20 @@ requirements: | |||||||
| - The `forward` method has a signature that is compatible with the | - The `forward` method has a signature that is compatible with the | ||||||
|   `forward` method that it is extending. |   `forward` method that it is extending. | ||||||
| 
 | 
 | ||||||
|  | There are two exceptions to the _no class variables rule_: | ||||||
|  | 
 | ||||||
|  | 1. The `has_backward` variable can be used to indicate whether the layer has | ||||||
|  |    a backward pass implemented (defaults to `True` when absent). | ||||||
|  | 2. The `can_torch_compile` variable can be used to indicate whether the layer | ||||||
|  |    supports `torch.compile` (defaults to `False` when absent). | ||||||
|  | 
 | ||||||
| This is an example of a pure layer: | This is an example of a pure layer: | ||||||
| 
 | 
 | ||||||
| ```python | ```python | ||||||
| class SiluAndMul(nn.Module): | class SiluAndMul(nn.Module): | ||||||
|  |     # This layer does not implement backward. | ||||||
|  |     has_backward: bool = False | ||||||
|  | 
 | ||||||
|     def forward(self, x: torch.Tensor): |     def forward(self, x: torch.Tensor): | ||||||
|         d = x.shape[-1] // 2 |         d = x.shape[-1] // 2 | ||||||
|         output_shape = x.shape[:-1] + (d,) |         output_shape = x.shape[:-1] + (d,) | ||||||
docs/source/layers.md (323 lines, new file)
							| @ -0,0 +1,323 @@ | |||||||
|  | # Layers | ||||||
|  |  | ||||||
|  | A kernel can provide layers in addition to kernel functions. A layer from | ||||||
|  | the Hub can replace the `forward` method of an existing layer for a certain | ||||||
|  | device type. This makes it possible to provide more performant kernels for | ||||||
|  | existing layers. | ||||||
|  |  | ||||||
|  | See [Kernel requirements](kernel-requirements.md) for more information on the | ||||||
|  | requirements of Hub layers. | ||||||
|  |  | ||||||
|  | ## Making a layer extensible with kernels from the Hub | ||||||
|  |  | ||||||
|  | ### Using a decorator | ||||||
|  |  | ||||||
|  | A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||||
|  | decorator. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | import torch.nn.functional as F | ||||||
|  |  | ||||||
|  | from kernels import use_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  | @use_kernel_forward_from_hub("SiluAndMul") | ||||||
|  | class SiluAndMul(nn.Module): | ||||||
|  |     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         d = input.shape[-1] // 2 | ||||||
|  |         return F.silu(input[..., :d]) * input[..., d:] | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The decorator does not change the behavior of the class -- it annotates | ||||||
|  | the class with the given name (here `SiluAndMul`). The `kernelize` function | ||||||
|  | described below uses this name to look up kernels for the layer. | ||||||
|  |  | ||||||
|  | ### External layers | ||||||
|  |  | ||||||
|  | An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||||
|  | decorator can be made extensible using the `replace_kernel_forward_from_hub` | ||||||
|  | function: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from somelibrary import SiluAndMul | ||||||
|  |  | ||||||
|  | replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Warning:** we strongly recommend using layers with a decorator, since | ||||||
|  | it signifies that the maintainer intends to keep the `forward` signature | ||||||
|  | compatible with layers from the Hub. | ||||||
|  |  | ||||||
|  | ## Kernelizing a model | ||||||
|  |  | ||||||
|  | A model will not use Hub kernels by default, even if it contains extensible | ||||||
|  | layers. To enable the use of Hub kernels in the model, it needs to be | ||||||
|  | 'kernelized' using the `kernelize` function. This function traverses the | ||||||
|  | model graph and replaces the `forward` methods of extensible layers for which | ||||||
|  | Hub kernels are registered. `kernelize` can be used as follows: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | model = MyModel(...) | ||||||
|  | model = kernelize(model, mode=Mode.INFERENCE) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The `kernelize` function modifies the model in-place; the model itself is | ||||||
|  | returned as a convenience. The `mode` argument specifies that the model will be | ||||||
|  | used for inference. Similarly, you can ask `kernelize` to prepare the model for | ||||||
|  | training: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | model = MyModel(...) | ||||||
|  | model = kernelize(model, mode=Mode.TRAINING) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | A model that is kernelized for training can also be used for inference, but | ||||||
|  | not the other way around. If you want to change the mode of the kernelized | ||||||
|  | model, you can just run `kernelize` on the model again with the new mode. | ||||||
|  |  | ||||||
|  | If you want to compile a model with `torch.compile`, this should be indicated | ||||||
|  | in the mode as well. You can do this by combining `Mode.INFERENCE` or | ||||||
|  | `Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | model = MyModel(...) | ||||||
|  |  | ||||||
|  | # Inference | ||||||
|  | model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE) | ||||||
|  |  | ||||||
|  | # Training | ||||||
|  | model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Kernel device | ||||||
|  |  | ||||||
|  | Kernels can be registered per device type. For instance, separate `cuda` and | ||||||
|  | `metal` kernels could be registered for the name `SiluAndMul`. By default, | ||||||
|  | `kernelize` will try to infer the device type from the model's parameters. | ||||||
|  | You can pass the device type to `kernelize` if the device type cannot be | ||||||
|  | inferred (e.g. because the model has no parameters): | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | model = MyModel(...) | ||||||
|  | model = kernelize(model, device="cuda", mode=Mode.INFERENCE) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### Fallback `forward` | ||||||
|  |  | ||||||
|  | If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered | ||||||
|  | kernel does not support backward passes or `torch.compile` respectively, | ||||||
|  | `kernelize` will fall back to the original, non-kernelized layer. You | ||||||
|  | can let `kernelize` raise an exception instead by using `use_fallback=False`: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | model = MyModel(...) | ||||||
|  | model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This can be useful if you want to guarantee that Hub kernels are used. | ||||||
|  |  | ||||||
|  | ### Inspecting which kernels are used | ||||||
|  |  | ||||||
|  | The kernels that are used are logged at the `INFO` level by `kernelize`. | ||||||
|  | See the [Python logging](https://docs.python.org/3/library/logging.html) | ||||||
|  | documentation for information on how to configure logging. | ||||||
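|  |  | ||||||
|  | For example, a minimal configuration that surfaces these messages: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | # Show INFO-level messages, including the kernels selected by `kernelize`. | ||||||
|  | logging.basicConfig(level=logging.INFO) | ||||||
|  | ``` | ||||||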
|  |  | ||||||
|  | ## Registering a hub kernel for a layer | ||||||
|  |  | ||||||
|  | `kernelize` relies on kernel mappings to find Hub kernels for layers. | ||||||
|  | Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on | ||||||
|  | the Hub. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |         ), | ||||||
|  |         "rocm": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | You can register such a mapping using `register_kernel_mapping`: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This will register the kernel mapping in the current context, which is | ||||||
|  | normally global. It is recommended to scope the mapping to where it is | ||||||
|  | used with the `use_kernel_mapping` context manager: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | with use_kernel_mapping(kernel_layer_mapping): | ||||||
|  |     # Use the layer for which the mapping is applied. | ||||||
|  |     model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This ensures that the mapping is no longer active outside the | ||||||
|  | `with` scope. | ||||||
|  |  | ||||||
|  | ### Using version bounds | ||||||
|  |  | ||||||
|  | Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`. | ||||||
|  | You can specify which version of the kernel to download using Python version | ||||||
|  | specifiers: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             version=">=0.0.4,<0.1.0", | ||||||
|  |         ), | ||||||
|  |         "rocm": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             version=">=0.0.4,<0.1.0", | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This will get the layer from the latest kernel tagged `v0.0.z` where `z` is at | ||||||
|  | least 4. It is strongly recommended to specify a version bound, since a | ||||||
|  | kernel author might push incompatible changes to the `main` branch. | ||||||
|  |  | ||||||
|  | ### Registering kernels for specific modes | ||||||
|  |  | ||||||
|  | You might want to register two different kernels for a particular layer, | ||||||
|  | where one kernel is optimized for a specific mode. You can do so by | ||||||
|  | registering layer repositories for specific modes. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": { | ||||||
|  |           Mode.INFERENCE: LayerRepository( | ||||||
|  |               repo_id="kernels-community/activation-inference-optimized", | ||||||
|  |               layer_name="SiluAndMul", | ||||||
|  |           ), | ||||||
|  |           Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository( | ||||||
|  |               repo_id="kernels-community/activation-training-optimized", | ||||||
|  |               layer_name="SiluAndMul", | ||||||
|  |           ), | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The `kernelize` function will attempt to use the following registered | ||||||
|  | kernels for a given mode: | ||||||
|  |  | ||||||
|  | - `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` → | ||||||
|  |   `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||||
|  | - `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` → | ||||||
|  |   `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||||
|  | - `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||||
|  | - `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK` | ||||||
|  |  | ||||||
|  | `Mode.FALLBACK` is a special mode that is used when no other mode matches. It | ||||||
|  | is also used when a kernel is registered without a mode, as described in the | ||||||
|  | previous section. | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": { | ||||||
|  |             Mode.FALLBACK: LayerRepository( | ||||||
|  |                 repo_id="kernels-community/activation", | ||||||
|  |                 layer_name="SiluAndMul", | ||||||
|  |             ), | ||||||
|  |             Mode.INFERENCE: LayerRepository( | ||||||
|  |                 repo_id="kernels-community/activation-inference-optimized", | ||||||
|  |                 layer_name="SiluAndMul", | ||||||
|  |             ), | ||||||
|  |             Mode.TRAINING: LayerRepository( | ||||||
|  |                 repo_id="kernels-community/activation-training-optimized", | ||||||
|  |                 layer_name="SiluAndMul", | ||||||
|  |             ), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and | ||||||
|  | `Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel, | ||||||
|  | since the other kernels do not support `torch.compile`. | ||||||
|  |  | ||||||
|  | ### Registering kernels for specific CUDA capabilities | ||||||
|  |  | ||||||
|  | Some kernels only work with newer CUDA architectures. For instance, some | ||||||
|  | kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels` | ||||||
|  | supports registering layers for a range of CUDA capabilities. To do so, | ||||||
|  | you need to register the layer for a `Device` with type `cuda` and | ||||||
|  | set the supported range of CUDA capabilities using `CUDAProperties`: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         Device( | ||||||
|  |             type="cuda", | ||||||
|  |             properties=CUDAProperties( | ||||||
|  |                 min_capability=75, max_capability=89 | ||||||
|  |             ), | ||||||
|  |         ): LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |         ), | ||||||
|  |         Device( | ||||||
|  |             type="cuda", | ||||||
|  |             properties=CUDAProperties( | ||||||
|  |                 min_capability=90, max_capability=sys.maxsize | ||||||
|  |             ), | ||||||
|  |         ): LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation-hopper", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |         ), | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Capabilities behave as follows: | ||||||
|  |  | ||||||
|  | - The minimum and maximum capabilities are inclusive. | ||||||
|  | - When a new kernel is registered with the same min/max capabilities as | ||||||
|  |   an existing kernel, the new kernel will replace the old kernel. | ||||||
|  | - When there are multiple kernels that support a capability, the kernel | ||||||
|  |   with the smaller capability interval will be used. E.g. given: | ||||||
|  |   - `KernelA` with `min_capability=80` and `max_capability=89`; | ||||||
|  |   - `KernelB` with `min_capability=75` and `max_capability=89`; | ||||||
|  |   - `kernelize` runs on a system with capability 8.6. | ||||||
|  |  | ||||||
|  |   Then `KernelA` will be used because the interval 80..89 is smaller | ||||||
|  |   than 75..89. The motivation is that kernels with smaller ranges | ||||||
|  |   tend to be more optimized for a specific set of GPUs. **This behavior | ||||||
|  |   might still change in the future.** | ||||||
|  |  | ||||||
|  | ### Registering kernels for specific ROCm capabilities | ||||||
|  |  | ||||||
|  | Registering kernels for the ROCm architecture follows the exact same | ||||||
|  | pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict | ||||||
|  | a kernel to a range of ROCm capabilities. | ||||||
|  |  | ||||||
|  | ### Loading from a local repository for testing | ||||||
|  |  | ||||||
|  | The `LocalLayerRepository` class is provided to load a repository from | ||||||
|  | a local directory. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | with use_kernel_mapping( | ||||||
|  |     { | ||||||
|  |         "SiluAndMul": { | ||||||
|  |             "cuda": LocalLayerRepository( | ||||||
|  |                 repo_path="/home/daniel/kernels/activation", | ||||||
|  |                 package_name="activation", | ||||||
|  |                 layer_name="SiluAndMul", | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     inherit_mapping=False, | ||||||
|  | ): | ||||||
|  |     kernelize(linear, mode=Mode.INFERENCE) | ||||||
|  | ``` | ||||||
| @ -1,4 +1,4 @@ | |||||||
| # Locking kernel versions | # Locking kernel/layer versions | ||||||
| 
 | 
 | ||||||
| Projects that use `setuptools` can lock the kernel versions that should be | Projects that use `setuptools` can lock the kernel versions that should be | ||||||
| used. First specify the accepted versions in `pyproject.toml` and make | used. First specify the accepted versions in `pyproject.toml` and make | ||||||
| @ -13,7 +13,7 @@ build-backend = "setuptools.build_meta" | |||||||
| "kernels-community/activation" = ">=0.0.1" | "kernels-community/activation" = ">=0.0.1" | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with | Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with | ||||||
| the locked revisions. The locked revision will be used when loading a kernel with | the locked revisions. The locked revision will be used when loading a kernel with | ||||||
| `get_locked_kernel`: | `get_locked_kernel`: | ||||||
| 
 | 
 | ||||||
| @ -26,9 +26,27 @@ activation = get_locked_kernel("kernels-community/activation") | |||||||
| **Note:** the lock file is included in the package metadata, so it will only be visible | **Note:** the lock file is included in the package metadata, so it will only be visible | ||||||
| to `kernels` after doing an (editable or regular) installation of your project. | to `kernels` after doing an (editable or regular) installation of your project. | ||||||
| 
 | 
 | ||||||
|  | ## Locked kernel layers | ||||||
|  | 
 | ||||||
|  | Locking is also supported for kernel layers. To use locked layers, register them | ||||||
|  | with the `LockedLayerRepository` class: | ||||||
|  | 
 | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": LockedLayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  | 
 | ||||||
| ## Pre-downloading locked kernels | ## Pre-downloading locked kernels | ||||||
| 
 | 
 | ||||||
| Locked kernels can be pre-downloaded by running `kernel download .` in your | Locked kernels can be pre-downloaded by running `kernels download .` in your | ||||||
| project directory. This will download the kernels to your local Hugging Face | project directory. This will download the kernels to your local Hugging Face | ||||||
| Hub cache. | Hub cache. | ||||||
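|  |  | ||||||
|  | For example, from the project directory: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | kernels download . | ||||||
|  | ``` | ||||||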
| 
 | 
 | ||||||
| @ -20,11 +20,11 @@ activation.gelu_fast(y, x) | |||||||
| print("Kernel successfully executed") | print("Kernel successfully executed") | ||||||
|  |  | ||||||
| # Check results | # Check results | ||||||
| expected = torch.tensor([ | expected = torch.tensor( | ||||||
|     [0.8408, 1.9551, 2.9961], |     [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||||
|     [4.0000, 5.0000, 6.0000], |     device="cuda:0", | ||||||
|     [7.0000, 8.0000, 9.0000] |     dtype=torch.float16, | ||||||
| ], device='cuda:0', dtype=torch.float16) | ) | ||||||
| assert torch.allclose(y, expected) | assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
| print("Calculated values are exact") | print("Calculated values are exact") | ||||||
|  | |||||||
flake.lock (63 lines, generated)
							| @ -51,30 +51,50 @@ | |||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
|     }, |     }, | ||||||
|     "nixpkgs": { |     "hf-nix": { | ||||||
|  |       "inputs": { | ||||||
|  |         "flake-compat": "flake-compat", | ||||||
|  |         "flake-utils": "flake-utils_2", | ||||||
|  |         "nixpkgs": "nixpkgs" | ||||||
|  |       }, | ||||||
|       "locked": { |       "locked": { | ||||||
|         "lastModified": 1737453259, |         "lastModified": 1754038838, | ||||||
|         "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", |         "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=", | ||||||
|         "owner": "danieldk", |         "owner": "huggingface", | ||||||
|         "repo": "nixpkgs", |         "repo": "hf-nix", | ||||||
|         "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", |         "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       }, |       }, | ||||||
|       "original": { |       "original": { | ||||||
|         "owner": "danieldk", |         "owner": "huggingface", | ||||||
|         "ref": "outlines-v0.1.4-tgi", |         "repo": "hf-nix", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "nixpkgs": { | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1752785354, | ||||||
|  |         "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=", | ||||||
|  |         "owner": "nixos", | ||||||
|         "repo": "nixpkgs", |         "repo": "nixpkgs", | ||||||
|  |         "rev": "d38025438a6ee456758dc03188ca6873a415463b", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "nixos", | ||||||
|  |         "repo": "nixpkgs", | ||||||
|  |         "rev": "d38025438a6ee456758dc03188ca6873a415463b", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
|     }, |     }, | ||||||
|     "root": { |     "root": { | ||||||
|       "inputs": { |       "inputs": { | ||||||
|         "flake-utils": "flake-utils", |         "flake-utils": "flake-utils", | ||||||
|  |         "hf-nix": "hf-nix", | ||||||
|         "nixpkgs": [ |         "nixpkgs": [ | ||||||
|           "tgi-nix", |           "hf-nix", | ||||||
|           "nixpkgs" |           "nixpkgs" | ||||||
|         ], |         ] | ||||||
|         "tgi-nix": "tgi-nix" |  | ||||||
|       } |       } | ||||||
|     }, |     }, | ||||||
|     "systems": { |     "systems": { | ||||||
| @ -106,27 +126,6 @@ | |||||||
|         "repo": "default", |         "repo": "default", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
|     }, |  | ||||||
|     "tgi-nix": { |  | ||||||
|       "inputs": { |  | ||||||
|         "flake-compat": "flake-compat", |  | ||||||
|         "flake-utils": "flake-utils_2", |  | ||||||
|         "nixpkgs": "nixpkgs" |  | ||||||
|       }, |  | ||||||
|       "locked": { |  | ||||||
|         "lastModified": 1741617161, |  | ||||||
|         "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=", |  | ||||||
|         "owner": "huggingface", |  | ||||||
|         "repo": "text-generation-inference-nix", |  | ||||||
|         "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925", |  | ||||||
|         "type": "github" |  | ||||||
|       }, |  | ||||||
|       "original": { |  | ||||||
|         "owner": "huggingface", |  | ||||||
|         "ref": "kernels-0.2.0", |  | ||||||
|         "repo": "text-generation-inference-nix", |  | ||||||
|         "type": "github" |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
|   }, |   }, | ||||||
|   "root": "root", |   "root": "root", | ||||||
|  | |||||||
flake.nix (22 lines)
							| @ -1,7 +1,7 @@ | |||||||
| { | { | ||||||
|   inputs = { |   inputs = { | ||||||
|     tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0"; |     hf-nix.url = "github:huggingface/hf-nix"; | ||||||
|     nixpkgs.follows = "tgi-nix/nixpkgs"; |     nixpkgs.follows = "hf-nix/nixpkgs"; | ||||||
|     flake-utils.url = "github:numtide/flake-utils"; |     flake-utils.url = "github:numtide/flake-utils"; | ||||||
|   }; |   }; | ||||||
|   outputs = |   outputs = | ||||||
| @ -9,23 +9,28 @@ | |||||||
|       self, |       self, | ||||||
|       nixpkgs, |       nixpkgs, | ||||||
|       flake-utils, |       flake-utils, | ||||||
|       tgi-nix, |       hf-nix, | ||||||
|     }: |     }: | ||||||
|     flake-utils.lib.eachDefaultSystem ( |     flake-utils.lib.eachDefaultSystem ( | ||||||
|       system: |       system: | ||||||
|       let |       let | ||||||
|         pkgs = import nixpkgs { |         pkgs = import nixpkgs { | ||||||
|           inherit system; |           inherit system; | ||||||
|           inherit (tgi-nix.lib) config; |           config = hf-nix.lib.config system; | ||||||
|           overlays = [ |           overlays = [ | ||||||
|             tgi-nix.overlays.default |             hf-nix.overlays.default | ||||||
|           ]; |           ]; | ||||||
|         }; |         }; | ||||||
|       in |       in | ||||||
|       { |       { | ||||||
|         formatter = pkgs.nixfmt-rfc-style; |         formatter = pkgs.nixfmt-tree; | ||||||
|  |         packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {}; | ||||||
|         devShells = with pkgs; rec { |         devShells = with pkgs; rec { | ||||||
|           default = mkShell { |           default = mkShell { | ||||||
|  |             nativeBuildInputs = [ | ||||||
|  |               # For hf-doc-builder. | ||||||
|  |               nodejs | ||||||
|  |             ]; | ||||||
|             buildInputs = |             buildInputs = | ||||||
|               [ |               [ | ||||||
|                 black |                 black | ||||||
| @ -34,10 +39,15 @@ | |||||||
|                 ruff |                 ruff | ||||||
|               ] |               ] | ||||||
|               ++ (with python3.pkgs; [ |               ++ (with python3.pkgs; [ | ||||||
|  |                 docutils | ||||||
|                 huggingface-hub |                 huggingface-hub | ||||||
|  |                 (callPackage ./nix/kernel-abi-check.nix {}) | ||||||
|  |                 mktestdocs | ||||||
|                 pytest |                 pytest | ||||||
|                 pytest-benchmark |                 pytest-benchmark | ||||||
|  |                 pyyaml | ||||||
|                 torch |                 torch | ||||||
|  |                 types-pyyaml | ||||||
|                 venvShellHook |                 venvShellHook | ||||||
|               ]); |               ]); | ||||||
|  |  | ||||||
|  | |||||||
nix/kernel-abi-check.nix (27 lines, new file)
							| @ -0,0 +1,27 @@ | |||||||
|  | { | ||||||
|  |   buildPythonPackage, | ||||||
|  |   fetchPypi, | ||||||
|  |   rustPlatform, | ||||||
|  | }: | ||||||
|  |  | ||||||
|  | buildPythonPackage rec { | ||||||
|  |   pname = "kernel-abi-check"; | ||||||
|  |   version = "0.6.2"; | ||||||
|  |  | ||||||
|  |   src = fetchPypi { | ||||||
|  |     inherit version; | ||||||
|  |     pname = "kernel_abi_check"; | ||||||
|  |     hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys="; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   cargoDeps = rustPlatform.fetchCargoVendor { | ||||||
|  |     inherit pname version src sourceRoot; | ||||||
|  |     hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A="; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   sourceRoot = "kernel_abi_check-${version}/bindings/python"; | ||||||
|  |  | ||||||
|  |   pyproject = true; | ||||||
|  |  | ||||||
|  |   nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ]; | ||||||
|  | } | ||||||
| @ -1,6 +1,6 @@ | |||||||
| [project] | [project] | ||||||
| name = "kernels" | name = "kernels" | ||||||
| version = "0.3.0" | version = "0.10.2.dev0" | ||||||
| description = "Download compute kernels" | description = "Download compute kernels" | ||||||
| authors = [ | authors = [ | ||||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, |   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||||
| @ -8,13 +8,14 @@ authors = [ | |||||||
|   { name = "David Holtz", email = "david@huggingface.co" }, |   { name = "David Holtz", email = "david@huggingface.co" }, | ||||||
|   { name = "Nicolas Patry", email = "nicolas@huggingface.co" }, |   { name = "Nicolas Patry", email = "nicolas@huggingface.co" }, | ||||||
| ] | ] | ||||||
|  | license = { text = "Apache-2.0" } | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
| requires-python = ">= 3.9" | requires-python = ">= 3.9" | ||||||
| dependencies = [ | dependencies = [ | ||||||
|   "huggingface-hub>=0.26.3", |   "huggingface_hub>=0.26.0,<2.0", | ||||||
|   "packaging>=24.2", |   "packaging>=20.0", | ||||||
|   "tomli>=2.0.1; python_version<'3.11'", |   "pyyaml>=6", | ||||||
|   "torch>=2.5", |   "tomli>=2.0; python_version<'3.11'", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| [build-system] | [build-system] | ||||||
| @ -23,10 +24,20 @@ build-backend = "setuptools.build_meta" | |||||||
|  |  | ||||||
| [dependency-groups] | [dependency-groups] | ||||||
| dev = [ | dev = [ | ||||||
|   "mypy == 1.14.1", |   "mktestdocs>=0.2.5", | ||||||
|   "pytest >=8", |   "mypy>=1.15.0", | ||||||
|  |   "pytest>=8", | ||||||
|   # Whatever version is compatible with pytest. |   # Whatever version is compatible with pytest. | ||||||
|   "pytest-benchmark", |   "pytest-benchmark", | ||||||
|  |   "torch>=2.5", | ||||||
|  |   "types-pyyaml" | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [project.optional-dependencies] | ||||||
|  | abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"] | ||||||
|  | torch = ["torch"] | ||||||
|  | docs = [ | ||||||
|  |   "hf-doc-builder", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| [project.scripts] | [project.scripts] | ||||||
| @ -35,6 +46,9 @@ kernels = "kernels.cli:main" | |||||||
| [project.entry-points."egg_info.writers"] | [project.entry-points."egg_info.writers"] | ||||||
| "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||||
|  |  | ||||||
|  | [tool.isort] | ||||||
|  | profile = "black" | ||||||
|  | line_length = 119 | ||||||
|  |  | ||||||
| [tool.ruff] | [tool.ruff] | ||||||
| exclude = [ | exclude = [ | ||||||
| @ -61,4 +75,4 @@ line-length = 119 | |||||||
| # Ignored rules: | # Ignored rules: | ||||||
| # "E501" -> line length violation | # "E501" -> line length violation | ||||||
| lint.ignore = ["E501"] | lint.ignore = ["E501"] | ||||||
| lint.select = ["E", "F", "I", "W"] | lint.select = ["E", "F", "W"] | ||||||
|  | |||||||
pytest.ini (9 lines, new file)
							| @ -0,0 +1,9 @@ | |||||||
|  | [pytest] | ||||||
|  | markers = | ||||||
|  |     cuda_only: marks tests that should only run on hosts with CUDA GPUs | ||||||
|  |     rocm_only: marks tests that should only run on hosts with ROCm GPUs | ||||||
|  |     darwin_only: marks tests that should only run on macOS | ||||||
|  |     xpu_only: marks tests that should only run on hosts with Intel XPUs | ||||||
|  |     npu_only: marks tests that should only run on Ascend NPUs | ||||||
|  |     token: marks tests that require a write token | ||||||
|  |     is_staging_test: marks tests that should only run in a staging environment | ||||||
| @ -1,23 +1,46 @@ | |||||||
|  | import importlib.metadata | ||||||
|  |  | ||||||
|  | __version__ = importlib.metadata.version("kernels") | ||||||
|  |  | ||||||
| from kernels.layer import ( | from kernels.layer import ( | ||||||
|  |     CUDAProperties, | ||||||
|     Device, |     Device, | ||||||
|     LayerRepository, |     LayerRepository, | ||||||
|  |     LocalLayerRepository, | ||||||
|  |     LockedLayerRepository, | ||||||
|  |     Mode, | ||||||
|  |     kernelize, | ||||||
|     register_kernel_mapping, |     register_kernel_mapping, | ||||||
|  |     replace_kernel_forward_from_hub, | ||||||
|     use_kernel_forward_from_hub, |     use_kernel_forward_from_hub, | ||||||
|  |     use_kernel_mapping, | ||||||
| ) | ) | ||||||
| from kernels.utils import ( | from kernels.utils import ( | ||||||
|     get_kernel, |     get_kernel, | ||||||
|  |     get_local_kernel, | ||||||
|     get_locked_kernel, |     get_locked_kernel, | ||||||
|  |     has_kernel, | ||||||
|     install_kernel, |     install_kernel, | ||||||
|     load_kernel, |     load_kernel, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "get_kernel", |     "__version__", | ||||||
|     "get_locked_kernel", |     "CUDAProperties", | ||||||
|     "load_kernel", |  | ||||||
|     "install_kernel", |  | ||||||
|     "use_kernel_forward_from_hub", |  | ||||||
|     "register_kernel_mapping", |  | ||||||
|     "LayerRepository", |  | ||||||
|     "Device", |     "Device", | ||||||
|  |     "LayerRepository", | ||||||
|  |     "LocalLayerRepository", | ||||||
|  |     "LockedLayerRepository", | ||||||
|  |     "Mode", | ||||||
|  |     "get_kernel", | ||||||
|  |     "get_local_kernel", | ||||||
|  |     "get_locked_kernel", | ||||||
|  |     "has_kernel", | ||||||
|  |     "install_kernel", | ||||||
|  |     "kernelize", | ||||||
|  |     "load_kernel", | ||||||
|  |     "register_kernel_mapping", | ||||||
|  |     "replace_kernel_forward_from_hub", | ||||||
|  |     "use_kernel_forward_from_hub", | ||||||
|  |     "use_kernel_mapping", | ||||||
| ] | ] | ||||||
|  | |||||||
src/kernels/_interval_tree.py (200 lines, new file)
							| @ -0,0 +1,200 @@ | |||||||
|  | # AVL-balanced interval trees. We could use the intervaltree | ||||||
|  | # package, but it seems unmaintained and does not have type | ||||||
|  | # annotations. | ||||||
|  |  | ||||||
|  | from typing import Generic, List, Optional, Tuple, TypeVar | ||||||
|  |  | ||||||
|  | T = TypeVar("T") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _Node(Generic[T]): | ||||||
|  |     """A node in the interval tree.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, start: int, end: int, data: T): | ||||||
|  |         self.start: int = start | ||||||
|  |         self.end: int = end | ||||||
|  |         self.data: T = data | ||||||
|  |         self.max_end: int = end | ||||||
|  |         self.left: Optional["_Node[T]"] = None | ||||||
|  |         self.right: Optional["_Node[T]"] = None | ||||||
|  |         self.height: int = 1 | ||||||
|  |  | ||||||
|  |     def __repr__(self) -> str: | ||||||
|  |         return f"Node({self.start}, {self.end})" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class IntervalTree(Generic[T]): | ||||||
|  |     """A data structure to hold and query (unique) intervals.""" | ||||||
|  |  | ||||||
|  |     root: Optional[_Node[T]] | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.root = None | ||||||
|  |  | ||||||
|  |     def insert(self, start: int, end: int, data: T) -> None: | ||||||
|  |         """ | ||||||
|  |         Inserts a new interval into the tree. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             start: The starting point of the interval. | ||||||
|  |             end: The ending point of the interval. | ||||||
|  |             data: The data associated with this interval. | ||||||
|  |         """ | ||||||
|  |         self.root = self._insert(self.root, start, end, data) | ||||||
|  |  | ||||||
|  |     def _get_height(self, node: Optional[_Node[T]]) -> int: | ||||||
|  |         if not node: | ||||||
|  |             return 0 | ||||||
|  |         return node.height | ||||||
|  |  | ||||||
|  |     def _get_balance(self, node: Optional[_Node[T]]) -> int: | ||||||
|  |         if not node: | ||||||
|  |             return 0 | ||||||
|  |         return self._get_height(node.left) - self._get_height(node.right) | ||||||
|  |  | ||||||
|  |     def _update_node_attributes(self, node: _Node[T]) -> None: | ||||||
|  |         node.height = 1 + max(self._get_height(node.left), self._get_height(node.right)) | ||||||
|  |         node.max_end = node.end | ||||||
|  |         if node.left: | ||||||
|  |             node.max_end = max(node.max_end, node.left.max_end) | ||||||
|  |         if node.right: | ||||||
|  |             node.max_end = max(node.max_end, node.right.max_end) | ||||||
|  |  | ||||||
|  |     def _right_rotate(self, y: _Node[T]) -> _Node[T]: | ||||||
|  |         """Performs a right rotation.""" | ||||||
|  |         x = y.left | ||||||
|  |         assert x is not None | ||||||
|  |         T2 = x.right | ||||||
|  |  | ||||||
|  |         x.right = y | ||||||
|  |         y.left = T2 | ||||||
|  |  | ||||||
|  |         self._update_node_attributes(y) | ||||||
|  |         self._update_node_attributes(x) | ||||||
|  |  | ||||||
|  |         return x | ||||||
|  |  | ||||||
|  |     def _left_rotate(self, x: _Node[T]) -> _Node[T]: | ||||||
|  |         """Performs a left rotation.""" | ||||||
|  |         y = x.right | ||||||
|  |         assert y is not None | ||||||
|  |         T2 = y.left | ||||||
|  |  | ||||||
|  |         y.left = x | ||||||
|  |         x.right = T2 | ||||||
|  |  | ||||||
|  |         self._update_node_attributes(x) | ||||||
|  |         self._update_node_attributes(y) | ||||||
|  |  | ||||||
|  |         return y | ||||||
|  |  | ||||||
|  |     def _insert( | ||||||
|  |         self, node: Optional[_Node[T]], start: int, end: int, data: T | ||||||
|  |     ) -> _Node[T]: | ||||||
|  |         """Recursive helper to insert a new node and balance the tree.""" | ||||||
|  |         if not node: | ||||||
|  |             return _Node(start, end, data) | ||||||
|  |  | ||||||
|  |         # Replace the data if the interval already exists. | ||||||
|  |         if start == node.start and end == node.end: | ||||||
|  |             node.data = data | ||||||
|  |             return node | ||||||
|  |  | ||||||
|  |         if start < node.start: | ||||||
|  |             node.left = self._insert(node.left, start, end, data) | ||||||
|  |         else: | ||||||
|  |             node.right = self._insert(node.right, start, end, data) | ||||||
|  |  | ||||||
|  |         self._update_node_attributes(node) | ||||||
|  |  | ||||||
|  |         balance = self._get_balance(node) | ||||||
|  |  | ||||||
|  |         # Left Left Case | ||||||
|  |         if balance > 1 and node.left and start < node.left.start: | ||||||
|  |             return self._right_rotate(node) | ||||||
|  |  | ||||||
|  |         # Right Right Case | ||||||
|  |         if balance < -1 and node.right and start >= node.right.start: | ||||||
|  |             return self._left_rotate(node) | ||||||
|  |  | ||||||
|  |         # Left Right Case | ||||||
|  |         if balance > 1 and node.left and start >= node.left.start: | ||||||
|  |             node.left = self._left_rotate(node.left) | ||||||
|  |             return self._right_rotate(node) | ||||||
|  |  | ||||||
|  |         # Right Left Case | ||||||
|  |         if balance < -1 and node.right and start < node.right.start: | ||||||
|  |             node.right = self._right_rotate(node.right) | ||||||
|  |             return self._left_rotate(node) | ||||||
|  |  | ||||||
|  |         return node | ||||||
|  |  | ||||||
|  |     def search(self, point: int) -> List[T]: | ||||||
|  |         """ | ||||||
|  |         Searches for all intervals that contain the given point. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             point: The point to search for. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             A list of data items from all matching intervals. | ||||||
|  |         """ | ||||||
|  |         results: List[T] = [] | ||||||
|  |         self._search(self.root, point, results) | ||||||
|  |         return results | ||||||
|  |  | ||||||
|  |     def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None: | ||||||
|  |         """Recursive helper to find all overlapping intervals.""" | ||||||
|  |         if node is None or point > node.max_end: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if node.left: | ||||||
|  |             self._search(node.left, point, results) | ||||||
|  |  | ||||||
|  |         if node.start <= point <= node.end: | ||||||
|  |             results.append(node.data) | ||||||
|  |  | ||||||
|  |         if point >= node.start and node.right: | ||||||
|  |             self._search(node.right, point, results) | ||||||
|  |  | ||||||
|  |     def find_smallest_interval(self, point: int) -> Optional[T]: | ||||||
|  |         """ | ||||||
|  |         Finds the item with the most specific (smallest) range for a given point. | ||||||
|  |  | ||||||
|  |         Args: | ||||||
|  |             point: The capability to look up. | ||||||
|  |  | ||||||
|  |         Returns: | ||||||
|  |             The data of the best-matching item, or None if no match is found. | ||||||
|  |         """ | ||||||
|  |         matches: List[Tuple[int, int, T]] = [] | ||||||
|  |         self._find_with_intervals(self.root, point, matches) | ||||||
|  |  | ||||||
|  |         if not matches: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # Return the smallest interval, sort by memory location when | ||||||
|  |         # there are multiple matches with the same interval size. This | ||||||
|  |         # is just to ensure that we can compare against a trivial | ||||||
|  |         # implementation in tests. | ||||||
|  |         best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2]))) | ||||||
|  |         return best_match[2] | ||||||
|  |  | ||||||
|  |     def _find_with_intervals( | ||||||
|  |         self, | ||||||
|  |         node: Optional[_Node[T]], | ||||||
|  |         point: int, | ||||||
|  |         results: List[Tuple[int, int, T]], | ||||||
|  |     ) -> None: | ||||||
|  |         """A modified search that collects interval ranges along with data.""" | ||||||
|  |         if node is None or point > node.max_end: | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if node.left: | ||||||
|  |             self._find_with_intervals(node.left, point, results) | ||||||
|  |  | ||||||
|  |         if node.start <= point <= node.end: | ||||||
|  |             results.append((node.start, node.end, node.data)) | ||||||
|  |  | ||||||
|  |         if point >= node.start and node.right: | ||||||
|  |             self._find_with_intervals(node.right, point, results) | ||||||
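|  |  | ||||||
|  |  | ||||||
|  | # Illustrative usage (not part of the module): | ||||||
|  | # | ||||||
|  | #     tree: IntervalTree[str] = IntervalTree() | ||||||
|  | #     tree.insert(75, 89, "kernel for sm 7.5-8.9") | ||||||
|  | #     tree.insert(80, 89, "kernel for sm 8.0-8.9") | ||||||
|  | #     tree.find_smallest_interval(86)  # -> "kernel for sm 8.0-8.9" | ||||||
|  | # | ||||||
|  | # The narrower 80..89 interval wins, matching the capability-resolution | ||||||
|  | # rules described in the layers documentation. | ||||||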
src/kernels/_vendored/convert_rst_to_mdx.py (751 lines, new file)
							| @ -0,0 +1,751 @@ | |||||||
|  | # coding=utf-8 | ||||||
|  | # Copyright 2021 The HuggingFace Team. All rights reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  |  | ||||||
|  | # Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py | ||||||
|  |  | ||||||
|  | import re | ||||||
|  |  | ||||||
|  | # Re pattern to catch things inside ` ` in :obj:`thing`. | ||||||
|  | _re_obj = re.compile(r":obj:`([^`]+)`") | ||||||
|  | # Re pattern to catch things inside ` ` in :math:`thing`. | ||||||
|  | _re_math = re.compile(r":math:`([^`]+)`") | ||||||
|  | # Re pattern to catch things between single backquotes. | ||||||
|  | _re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)") | ||||||
|  | # Re pattern to catch things between double backquotes. | ||||||
|  | _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)") | ||||||
|  | # Re pattern to catch things inside ` ` in :func/class/meth:`thing`. | ||||||
|  | _re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def convert_rst_formatting(text): | ||||||
|  |     """ | ||||||
|  |     Convert rst syntax for formatting to markdown in a given text. | ||||||
|  |     """ | ||||||
|  |     # Remove :class:, :func: and :meth: markers. To code-links and put double backquotes | ||||||
|  |     # (to not be caught by the italic conversion). | ||||||
|  |     text = _re_func_class.sub(r"[``\1``]", text) | ||||||
|  |     # Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes | ||||||
|  |     # (to not be caught by the italic conversion). | ||||||
|  |     text = _re_obj.sub(r"``\1``", text) | ||||||
|  |     # Remove :math: markers. | ||||||
|  |     text = _re_math.sub(r"\\\\(\1\\\\)", text) | ||||||
|  |     # Convert content in single backquotes to italic. | ||||||
|  |     text = _re_single_backquotes.sub(r"\1*\2*\3", text) | ||||||
|  |     # Convert content in double backquotes to single backquotes. | ||||||
|  |     text = _re_double_backquotes.sub(r"\1`\2`\3", text) | ||||||
|  |     # Remove remaining :: | ||||||
|  |     text = re.sub(r"::\n", "", text) | ||||||
|  |  | ||||||
|  |     # Remove new lines inside blocks in backticks as they will be kept. | ||||||
|  |     lines = text.split("\n") | ||||||
|  |     in_code = False | ||||||
|  |     text = None | ||||||
|  |     for line in lines: | ||||||
|  |         if in_code: | ||||||
|  |             splits = line.split("`") | ||||||
|  |             in_code = len(splits) > 1 and len(splits) % 2 == 1 | ||||||
|  |             if len(splits) == 1: | ||||||
|  |                 # Some forgotten lone backtick | ||||||
|  |                 text += "\n" + line | ||||||
|  |             else: | ||||||
|  |                 text += " " + line.lstrip() | ||||||
|  |         else: | ||||||
|  |             if text is not None: | ||||||
|  |                 text += "\n" + line | ||||||
|  |             else: | ||||||
|  |                 text = line | ||||||
|  |             splits = line.split("`") | ||||||
|  |             in_code = len(splits) % 2 == 0 | ||||||
|  |     return text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Re pattern to catch description and url in links of the form `description <url>`_. | ||||||
|  | _re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+") | ||||||
|  | # Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_. | ||||||
|  | _re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`") | ||||||
|  | # Re pattern to catch reference in links of the form :doc:`reference`. | ||||||
|  | _re_simple_doc = re.compile(r":doc:`([^`<]*)`") | ||||||
|  | # Re pattern to catch description and reference in links of the form :doc:`description <reference>`. | ||||||
|  | _re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`") | ||||||
|  | # Re pattern to catch reference in links of the form :ref:`reference`. | ||||||
|  | _re_simple_ref = re.compile(r":ref:`([^`<]*)`") | ||||||
|  | # Re pattern to catch description and reference in links of the form :ref:`description <reference>`. | ||||||
|  | _re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def convert_rst_links(text, page_info): | ||||||
|  |     """ | ||||||
|  |     Convert the rst links in text to markdown. | ||||||
|  |     """ | ||||||
|  |     if "package_name" not in page_info: | ||||||
|  |         raise ValueError("`page_info` must contain at least the package_name.") | ||||||
|  |     package_name = page_info["package_name"] | ||||||
|  |     version = page_info.get("version", "main") | ||||||
|  |     language = page_info.get("language", "en") | ||||||
|  |     no_prefix = page_info.get("no_prefix", False) | ||||||
|  |  | ||||||
|  |     prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/" | ||||||
|  |     # Links of the form :doc:`page` | ||||||
|  |     text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text) | ||||||
|  |     # Links of the form :doc:`text <page>` | ||||||
|  |     text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text) | ||||||
|  |  | ||||||
|  |     if "page" in page_info and not no_prefix: | ||||||
|  |         page = str(page_info["page"]) | ||||||
|  |         if page.endswith(".html"): | ||||||
|  |             page = page[:-5] | ||||||
|  |         prefix = f"{prefix}{page}" | ||||||
|  |     else: | ||||||
|  |         prefix = "" | ||||||
|  |     # Refs of the form :ref:`page` | ||||||
|  |     text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text) | ||||||
|  |     # Refs of the form :ref:`text <page>` | ||||||
|  |     text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text) | ||||||
|  |  | ||||||
|  |     # Links with a prefix | ||||||
|  |     # TODO: when it exists, use the API to deal with prefix links properly. | ||||||
|  |     prefix = f"https://github.com/huggingface/{package_name}/tree/main/" | ||||||
|  |     text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text) | ||||||
|  |     # Other links | ||||||
|  |     text = _re_links.sub(r"[\1](\2)", text) | ||||||
|  |     # Relative links or Transformers links need to remove the .html | ||||||
|  |     if ( | ||||||
|  |         "(https://https://huggingface.co/" in text | ||||||
|  |         or re.search(r"\(\.+/", text) is not None | ||||||
|  |     ): | ||||||
|  |         text = text.replace(".html", "") | ||||||
|  |     return text | ||||||
|  |  | ||||||
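
A minimal usage sketch of the link conversion (the `page_info` values are illustrative):

    >>> convert_rst_links(":doc:`quicktour`", {"package_name": "kernels"})
    '[quicktour](/docs/kernels/main/en/quicktour)'
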
|  |  | ||||||
|  | # Re pattern that catches examples blocks of the form `Example::`. | ||||||
|  | _re_example = re.compile(r"^\s*(\S.*)::\s*$") | ||||||
|  | # Re pattern that catches rst blocks of the form `.. block_name::`. | ||||||
|  | _re_block = re.compile(r"^\s*\.\.\s+(\S+)::") | ||||||
|  | # Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`. | ||||||
|  | _re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def is_empty_line(line): | ||||||
|  |     return len(line) == 0 or line.isspace() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def find_indent(line): | ||||||
|  |     """ | ||||||
|  |     Returns the number of spaces that start a line indent. | ||||||
|  |     """ | ||||||
|  |     search = re.search(r"^(\s*)(?:\S|$)", line) | ||||||
|  |     if search is None: | ||||||
|  |         return 0 | ||||||
|  |     return len(search.groups()[0]) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _re_rst_option = re.compile(r"^\s*:(\S+):(.*)$") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def convert_special_chars(text): | ||||||
|  |     """ | ||||||
|  |     Converts { and < that have special meanings in MDX. | ||||||
|  |     """ | ||||||
|  |     text = text.replace("{", "&lcub;") | ||||||
|  |     # We don't want to replace these with their HTML code yet, so we temporarily rewrite them as LTHTML | ||||||
|  |     text = re.sub( | ||||||
|  |         r"<(img|br|hr|Youtube)", r"LTHTML\1", text | ||||||
|  |     )  # html void elements with no closing counterpart | ||||||
|  |     _re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL) | ||||||
|  |     while _re_lt_html.search(text): | ||||||
|  |         text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text) | ||||||
|  |     text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text) | ||||||
|  |     text = text.replace("LTHTML", "<") | ||||||
|  |     return text | ||||||
|  |  | ||||||
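
For instance, a lone `<` and braces are escaped while the rest of the text is untouched (input illustrative):

    >>> convert_special_chars("a < b and {x}")
    'a &lt; b and &lcub;x}'
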
|  |  | ||||||
|  | def parse_options(block_content): | ||||||
|  |     """ | ||||||
|  |     Parses the options in some rst block content. | ||||||
|  |     """ | ||||||
|  |     block_lines = block_content.split("\n") | ||||||
|  |     block_indent = find_indent(block_lines[0]) | ||||||
|  |     current_option = None | ||||||
|  |     result = {} | ||||||
|  |     for line in block_lines: | ||||||
|  |         if _re_rst_option.search(line) is not None: | ||||||
|  |             current_option, value = _re_rst_option.search(line).groups() | ||||||
|  |             result[current_option] = value.lstrip() | ||||||
|  |         elif find_indent(line) > block_indent: | ||||||
|  |             result[current_option] += " " + line.lstrip() | ||||||
|  |  | ||||||
|  |     return result | ||||||
|  |  | ||||||
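
A sketch of the expected shape, with option block content as it appears under an rst directive (values illustrative):

    >>> parse_options("    :members: forward, generate\n    :special-members: __init__")
    {'members': 'forward, generate', 'special-members': '__init__'}
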
|  |  | ||||||
|  | def apply_min_indent(text, min_indent): | ||||||
|  |     """ | ||||||
|  |     Make sure all lines in a text have a minimum indentation. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         text (`str`): The text to treat. | ||||||
|  |         min_indent (`int`): The minimal indentation. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `str`: The processed text. | ||||||
|  |     """ | ||||||
|  |     lines = text.split("\n") | ||||||
|  |     idx = 0 | ||||||
|  |     while idx < len(lines): | ||||||
|  |         if is_empty_line(lines[idx]): | ||||||
|  |             idx += 1 | ||||||
|  |             continue | ||||||
|  |         indent = find_indent(lines[idx]) | ||||||
|  |         if indent < min_indent: | ||||||
|  |             while idx < len(lines) and ( | ||||||
|  |                 find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]) | ||||||
|  |             ): | ||||||
|  |                 if not is_empty_line(lines[idx]): | ||||||
|  |                     lines[idx] = " " * (min_indent - indent) + lines[idx] | ||||||
|  |                 idx += 1 | ||||||
|  |         else: | ||||||
|  |             idx += 1 | ||||||
|  |  | ||||||
|  |     return "\n".join(lines) | ||||||
|  |  | ||||||
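
For example, an under-indented block is shifted as a whole, preserving relative indentation (input illustrative):

    >>> apply_min_indent("foo\n  bar", 4)
    '    foo\n      bar'
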
|  |  | ||||||
|  | def convert_rst_blocks(text, page_info): | ||||||
|  |     """ | ||||||
|  |     Converts rst special blocks (examples, notes) into MDX. | ||||||
|  |     """ | ||||||
|  |     if "package_name" not in page_info: | ||||||
|  |         raise ValueError("`page_info` must contain at least the package_name.") | ||||||
|  |     package_name = page_info["package_name"] | ||||||
|  |     version = page_info.get("version", "main") | ||||||
|  |     language = page_info.get("language", "en") | ||||||
|  |  | ||||||
|  |     lines = text.split("\n") | ||||||
|  |     idx = 0 | ||||||
|  |     new_lines = [] | ||||||
|  |     while idx < len(lines): | ||||||
|  |         block_type = None | ||||||
|  |         block_info = None | ||||||
|  |         if _re_block.search(lines[idx]) is not None: | ||||||
|  |             block_type = _re_block.search(lines[idx]).groups()[0] | ||||||
|  |             if _re_block_info.search(lines[idx]) is not None: | ||||||
|  |                 block_info = _re_block_info.search(lines[idx]).groups()[0] | ||||||
|  |         elif _re_example.search(lines[idx]) is not None: | ||||||
|  |             block_type = "code-block-example" | ||||||
|  |             block_info = "python" | ||||||
|  |             example_name = _re_example.search(lines[idx]).groups()[0] | ||||||
|  |             new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n") | ||||||
|  |         elif lines[idx].strip() == "..": | ||||||
|  |             block_type = "comment" | ||||||
|  |         elif lines[idx].strip() == "::": | ||||||
|  |             block_type = "code-block" | ||||||
|  |  | ||||||
|  |         if block_type is not None: | ||||||
|  |             block_indent = find_indent(lines[idx]) | ||||||
|  |             # Find the next nonempty line | ||||||
|  |             idx += 1 | ||||||
|  |             while idx < len(lines) and is_empty_line(lines[idx]): | ||||||
|  |                 idx += 1 | ||||||
|  |             # Grab the indent of the first block line; this block stops when we unindent under it (or if we already have) | ||||||
|  |             example_indent = ( | ||||||
|  |                 find_indent(lines[idx]) if idx < len(lines) else block_indent | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             if example_indent == block_indent: | ||||||
|  |                 block_content = "" | ||||||
|  |             else: | ||||||
|  |                 block_lines = [] | ||||||
|  |                 while idx < len(lines) and ( | ||||||
|  |                     is_empty_line(lines[idx]) | ||||||
|  |                     or find_indent(lines[idx]) >= example_indent | ||||||
|  |                 ): | ||||||
|  |                     block_lines.append(lines[idx][example_indent:]) | ||||||
|  |                     idx += 1 | ||||||
|  |                 block_content = "\n".join(block_lines) | ||||||
|  |  | ||||||
|  |             if block_type in ["code", "code-block"]: | ||||||
|  |                 prefix = "```" if block_info is None else f"```{block_info}" | ||||||
|  |                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n") | ||||||
|  |             elif block_type == "code-block-example": | ||||||
|  |                 prefix = f"<example>```{block_info}" | ||||||
|  |                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>") | ||||||
|  |             elif block_type == "note": | ||||||
|  |                 new_lines.append( | ||||||
|  |                     apply_min_indent( | ||||||
|  |                         f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |             elif block_type == "warning": | ||||||
|  |                 new_lines.append( | ||||||
|  |                     apply_min_indent( | ||||||
|  |                         "<Tip warning={true}>\n\n" | ||||||
|  |                         + f"{block_content.strip()}\n\n</Tip>\n", | ||||||
|  |                         block_indent, | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |             elif block_type == "raw": | ||||||
|  |                 new_lines.append(block_content.strip() + "\n") | ||||||
|  |             elif block_type == "math": | ||||||
|  |                 new_lines.append(f"$${block_content.strip()}$$\n") | ||||||
|  |             elif block_type == "comment": | ||||||
|  |                 new_lines.append(f"<!--{block_content.strip()}\n-->\n") | ||||||
|  |             elif block_type == "autofunction": | ||||||
|  |                 if block_info is not None: | ||||||
|  |                     new_lines.append(f"[[autodoc]] {block_info}\n") | ||||||
|  |             elif block_type == "autoclass": | ||||||
|  |                 if block_info is not None: | ||||||
|  |                     block = f"[[autodoc]] {block_info}\n" | ||||||
|  |                     options = parse_options(block_content) | ||||||
|  |                     if "special-members" in options: | ||||||
|  |                         special_members = options["special-members"].split(", ") | ||||||
|  |                         for special_member in special_members: | ||||||
|  |                             block += f"    - {special_member}\n" | ||||||
|  |                     if "members" in options: | ||||||
|  |                         members = options["members"] | ||||||
|  |                         if len(members) == 0: | ||||||
|  |                             block += "    - all\n" | ||||||
|  |                         else: | ||||||
|  |                             for member in members.split(", "): | ||||||
|  |                                 block += f"    - {member}\n" | ||||||
|  |                     new_lines.append(block) | ||||||
|  |             elif block_type == "image": | ||||||
|  |                 options = parse_options(block_content) | ||||||
|  |                 target = options.pop("target", None) | ||||||
|  |                 if block_info is not None: | ||||||
|  |                     options["src"] = block_info | ||||||
|  |                 else: | ||||||
|  |                     if target is None: | ||||||
|  |                         raise ValueError("Image source not defined.") | ||||||
|  |                     options["src"] = target | ||||||
|  |                 # Adapt path | ||||||
|  |                 options["src"] = options["src"].replace( | ||||||
|  |                     "/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/" | ||||||
|  |                 ) | ||||||
|  |                 html_code = " ".join( | ||||||
|  |                     [f'{key}="{value}"' for key, value in options.items()] | ||||||
|  |                 ) | ||||||
|  |                 new_lines.append(f"<img {html_code}/>\n") | ||||||
|  |  | ||||||
|  |             else: | ||||||
|  |                 new_lines.append( | ||||||
|  |                     f"{block_type},{block_info}\n{block_content.rstrip()}\n" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             new_lines.append(lines[idx]) | ||||||
|  |             idx += 1 | ||||||
|  |  | ||||||
|  |     return "\n".join(new_lines) | ||||||
|  |  | ||||||
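
As a small sketch, a `.. note::` directive becomes a `<Tip>` block (inputs illustrative):

    >>> convert_rst_blocks(".. note::\n\n    Be careful.", {"package_name": "kernels"})
    '<Tip>\n\nBe careful.\n\n</Tip>\n'
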
|  |  | ||||||
|  | # Re pattern that catches rst args blocks of the form `Parameters:`. | ||||||
|  | _re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$") | ||||||
|  | # Re pattern that catches return blocks of the form `Return:`. | ||||||
|  | _re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_return_line(line): | ||||||
|  |     """ | ||||||
|  |     Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:. | ||||||
|  |     """ | ||||||
|  |     splits_on_colon = line.split(":") | ||||||
|  |     idx = 1 | ||||||
|  |     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: | ||||||
|  |         idx += 2 | ||||||
|  |     if idx >= len(splits_on_colon): | ||||||
|  |         if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()): | ||||||
|  |             return line, "" | ||||||
|  |         return None, line | ||||||
|  |     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) | ||||||
|  |  | ||||||
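
For instance (line illustrative):

    >>> split_return_line("    :obj:`int`: The number of tokens.")
    ('    :obj:`int`', ' The number of tokens.')
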
|  |  | ||||||
|  | def split_raise_line(line): | ||||||
|  |     """ | ||||||
|  |     Split the raise line with format `SomeError some doc`. | ||||||
|  |     """ | ||||||
|  |     splits_on_colon = line.strip().split(" ") | ||||||
|  |     error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:]) | ||||||
|  |     if error_type and error_type[-1] == ":": | ||||||
|  |         error_type = error_type[:-1] | ||||||
|  |     return error_type, doc | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_arg_line(line): | ||||||
|  |     """ | ||||||
|  |     Split an arg line with format `name: some doc`. The name may contain colons in the form of :obj: or :class:. | ||||||
|  |     """ | ||||||
|  |     splits_on_colon = line.split(":") | ||||||
|  |     idx = 1 | ||||||
|  |     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: | ||||||
|  |         idx += 2 | ||||||
|  |     if idx >= len(splits_on_colon): | ||||||
|  |         return line, "" | ||||||
|  |     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) | ||||||
|  |  | ||||||
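
For instance (line illustrative):

    >>> split_arg_line("    x (`int`): The input.")
    ('    x (`int`)', ' The input.')
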
|  |  | ||||||
|  | class InvalidRstDocstringError(ValueError): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _re_parameters = re.compile( | ||||||
|  |     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL | ||||||
|  | ) | ||||||
|  | _re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def parse_rst_docstring(docstring): | ||||||
|  |     """ | ||||||
|  |     Parses a docstring written in rst, in particular the list of arguments and the return type. | ||||||
|  |     """ | ||||||
|  |     lines = docstring.split("\n") | ||||||
|  |     idx = 0 | ||||||
|  |     while idx < len(lines): | ||||||
|  |         # Parameters section | ||||||
|  |         if _re_args.search(lines[idx]) is not None: | ||||||
|  |             # Title of the section. | ||||||
|  |             lines[idx] = "<parameters>\n" | ||||||
|  |             # Find the next nonempty line | ||||||
|  |             idx += 1 | ||||||
|  |             while is_empty_line(lines[idx]): | ||||||
|  |                 idx += 1 | ||||||
|  |             # Grab the indent of the list of parameters; this block stops when we unindent | ||||||
|  |             # under it or when we reach the Returns or Raises block. | ||||||
|  |             param_indent = find_indent(lines[idx]) | ||||||
|  |             while ( | ||||||
|  |                 idx < len(lines) | ||||||
|  |                 and find_indent(lines[idx]) == param_indent | ||||||
|  |                 and _re_returns.search(lines[idx]) is None | ||||||
|  |             ): | ||||||
|  |                 intro, doc = split_arg_line(lines[idx]) | ||||||
|  |                 # A line starting with a > after the indent indicates a "section title" in the parameters. | ||||||
|  |                 if intro.lstrip().startswith(">"): | ||||||
|  |                     lines[idx] = intro.lstrip() | ||||||
|  |                 else: | ||||||
|  |                     lines[idx] = ( | ||||||
|  |                         re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc | ||||||
|  |                     ) | ||||||
|  |                 idx += 1 | ||||||
|  |                 while idx < len(lines) and ( | ||||||
|  |                     is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent | ||||||
|  |                 ): | ||||||
|  |                     idx += 1 | ||||||
|  |             lines.insert(idx, "</parameters>\n") | ||||||
|  |             idx += 1 | ||||||
|  |  | ||||||
|  |         # Returns section | ||||||
|  |         elif _re_returns.search(lines[idx]) is not None: | ||||||
|  |             # tag is either `return` or `yield` | ||||||
|  |             tag = _re_returns.match(lines[idx]).group(1).lower() | ||||||
|  |             # Title of the section. | ||||||
|  |             lines[idx] = f"<{tag}s>\n" | ||||||
|  |             # Find the next nonempty line | ||||||
|  |             idx += 1 | ||||||
|  |             while is_empty_line(lines[idx]): | ||||||
|  |                 idx += 1 | ||||||
|  |  | ||||||
|  |             # Grab the indent of the return line; this block stops when we unindent under it. | ||||||
|  |             return_indent = find_indent(lines[idx]) | ||||||
|  |             raised_errors = [] | ||||||
|  |             # The line may contain the return type. | ||||||
|  |             if tag in ["return", "yield"]: | ||||||
|  |                 return_type, return_description = split_return_line(lines[idx]) | ||||||
|  |                 lines[idx] = return_description | ||||||
|  |                 idx += 1 | ||||||
|  |                 while idx < len(lines) and ( | ||||||
|  |                     is_empty_line(lines[idx]) | ||||||
|  |                     or find_indent(lines[idx]) >= return_indent | ||||||
|  |                 ): | ||||||
|  |                     idx += 1 | ||||||
|  |             else: | ||||||
|  |                 while idx < len(lines) and find_indent(lines[idx]) == return_indent: | ||||||
|  |                     return_type, return_description = split_raise_line(lines[idx]) | ||||||
|  |                     raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type) | ||||||
|  |                     lines[idx] = "- " + raised_error + " -- " + return_description | ||||||
|  |                     md_link = _re_md_link.match(raised_error) | ||||||
|  |                     if md_link: | ||||||
|  |                         raised_error = md_link[1] | ||||||
|  |                         raised_error = re.sub( | ||||||
|  |                             r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error | ||||||
|  |                         ) | ||||||
|  |                     if raised_error not in raised_errors: | ||||||
|  |                         raised_errors.append(raised_error) | ||||||
|  |                     idx += 1 | ||||||
|  |                     while idx < len(lines) and ( | ||||||
|  |                         is_empty_line(lines[idx]) | ||||||
|  |                         or find_indent(lines[idx]) > return_indent | ||||||
|  |                     ): | ||||||
|  |                         idx += 1 | ||||||
|  |  | ||||||
|  |             lines.insert(idx, f"</{tag}s>\n") | ||||||
|  |             idx += 1 | ||||||
|  |  | ||||||
|  |             # The return block is finished; insert the return type if one was specified | ||||||
|  |             if tag in ["return", "yield"] and return_type is not None: | ||||||
|  |                 lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n" | ||||||
|  |             elif len(raised_errors) > 0: | ||||||
|  |                 # raised errors | ||||||
|  |                 lines[ | ||||||
|  |                     idx - 1 | ||||||
|  |                 ] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n" | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             idx += 1 | ||||||
|  |  | ||||||
|  |     result = "\n".join(lines) | ||||||
|  |  | ||||||
|  |     # combine multiple <parameters> blocks into one block | ||||||
|  |     if result.count("<parameters>") > 1: | ||||||
|  |         parameters_blocks = _re_parameters.findall(result) | ||||||
|  |         parameters_blocks = [pb[0].strip() for pb in parameters_blocks] | ||||||
|  |         parameters_str = "\n".join(parameters_blocks) | ||||||
|  |         result = _re_parameters.sub("", result) | ||||||
|  |         result += f"\n<parameters>{parameters_str}</parameters>\n" | ||||||
|  |  | ||||||
|  |     return result | ||||||
|  |  | ||||||
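
A rough sketch of the transformation on a toy docstring (output abridged; whitespace not shown exactly):

    doc = "Args:\n    x (`int`): The input.\n\nReturns:\n    `int`: The doubled input."
    print(parse_rst_docstring(doc))
    # <parameters>
    # - **x** (`int`) -- The input.
    # </parameters>
    # <returns> The doubled input. </returns>
    # <returntype>`int`</returntype>
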
|  |  | ||||||
|  | _re_list = re.compile(r"^\s*(-|\*|\d+\.)\s") | ||||||
|  | _re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def remove_indent(text): | ||||||
|  |     """ | ||||||
|  |     Remove indents in text, except those linked to lists (or sublists). | ||||||
|  |     """ | ||||||
|  |     lines = text.split("\n") | ||||||
|  |     # List of indents to remember for nested lists | ||||||
|  |     current_indents = [] | ||||||
|  |     # List of new indents to remember for nested lists | ||||||
|  |     new_indents = [] | ||||||
|  |     is_inside_code = False | ||||||
|  |     code_indent = 0 | ||||||
|  |     for idx, line in enumerate(lines): | ||||||
|  |         # Line is an item in a list. | ||||||
|  |         if _re_list.search(line) is not None: | ||||||
|  |             indent = find_indent(line) | ||||||
|  |             # Is it a new list / new level of nestedness? | ||||||
|  |             if len(current_indents) == 0 or indent > current_indents[-1]: | ||||||
|  |                 current_indents.append(indent) | ||||||
|  |                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||||
|  |                 lines[idx] = " " * new_indent + line[indent:] | ||||||
|  |                 new_indent += len(_re_list.search(line).groups()[0]) + 1 | ||||||
|  |                 new_indents.append(new_indent) | ||||||
|  |             # Otherwise it's an existing level of list (current one, or previous one) | ||||||
|  |             else: | ||||||
|  |                 # Let's find the proper level of indentation | ||||||
|  |                 level = len(current_indents) - 1 | ||||||
|  |                 while level >= 0 and current_indents[level] != indent: | ||||||
|  |                     level -= 1 | ||||||
|  |                 current_indents = current_indents[: level + 1] | ||||||
|  |                 new_indents = new_indents[:level] | ||||||
|  |                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||||
|  |                 lines[idx] = " " * new_indent + line[indent:] | ||||||
|  |                 new_indent += len(_re_list.search(line).groups()[0]) + 1 | ||||||
|  |                 new_indents.append(new_indent) | ||||||
|  |  | ||||||
|  |         # Line is an autodoc; keep the indent for the list just after it, if there is one. | ||||||
|  |         elif _re_autodoc.search(line) is not None: | ||||||
|  |             indent = find_indent(line) | ||||||
|  |             current_indents = [indent] | ||||||
|  |             new_indents = [4] | ||||||
|  |             lines[idx] = line.strip() | ||||||
|  |  | ||||||
|  |         # Deal with empty lines separately | ||||||
|  |         elif is_empty_line(line): | ||||||
|  |             lines[idx] = "" | ||||||
|  |  | ||||||
|  |         # Code blocks | ||||||
|  |         elif line.lstrip().startswith("```"): | ||||||
|  |             is_inside_code = not is_inside_code | ||||||
|  |             if is_inside_code: | ||||||
|  |                 code_indent = find_indent(line) | ||||||
|  |             lines[idx] = line[code_indent:] | ||||||
|  |         elif is_inside_code: | ||||||
|  |             lines[idx] = line[code_indent:] | ||||||
|  |  | ||||||
|  |         else: | ||||||
|  |             indent = find_indent(line) | ||||||
|  |             if len(current_indents) > 0 and indent > current_indents[-1]: | ||||||
|  |                 lines[idx] = " " * new_indents[-1] + line[indent:] | ||||||
|  |             elif len(current_indents) > 0: | ||||||
|  |                 # Let's find the proper level of indentation | ||||||
|  |                 level = len(current_indents) - 1 | ||||||
|  |                 while level >= 0 and current_indents[level] > indent: | ||||||
|  |                     level -= 1 | ||||||
|  |                 current_indents = current_indents[: level + 1] | ||||||
|  |                 if level >= 0: | ||||||
|  |                     if current_indents[level] < indent: | ||||||
|  |                         new_indents = new_indents[: level + 1] | ||||||
|  |                     else: | ||||||
|  |                         new_indents = new_indents[:level] | ||||||
|  |                     new_indent = 0 if len(new_indents) == 0 else new_indents[-1] | ||||||
|  |                     lines[idx] = " " * new_indent + line[indent:] | ||||||
|  |                     new_indents.append(new_indent) | ||||||
|  |                 else: | ||||||
|  |                     new_indents = [] | ||||||
|  |                     lines[idx] = line[indent:] | ||||||
|  |             else: | ||||||
|  |                 lines[idx] = line[indent:] | ||||||
|  |  | ||||||
|  |     return "\n".join(lines) | ||||||
|  |  | ||||||
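
For instance, a list item's continuation line is re-indented to the width of the list marker (input illustrative):

    >>> remove_indent("- item\n    continuation")
    '- item\n  continuation'
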
|  |  | ||||||
|  | def base_rst_to_mdx(text, page_info, unindent=True): | ||||||
|  |     """ | ||||||
|  |     Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs. | ||||||
|  |     """ | ||||||
|  |     text = convert_rst_links(text, page_info) | ||||||
|  |     text = convert_special_chars(text) | ||||||
|  |     text = convert_rst_blocks(text, page_info) | ||||||
|  |     # Convert * in lists to - so the formatting conversion does not treat them as bold. | ||||||
|  |     text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE) | ||||||
|  |     text = convert_rst_formatting(text) | ||||||
|  |     return remove_indent(text) if unindent else text | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def convert_rst_docstring_to_mdx(docstring, page_info): | ||||||
|  |     """ | ||||||
|  |     Convert a docstring written in rst to mdx. | ||||||
|  |     """ | ||||||
|  |     text = parse_rst_docstring(docstring) | ||||||
|  |     return base_rst_to_mdx(text, page_info) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def process_titles(lines): | ||||||
|  |     """Converts rst titles to markdown titles.""" | ||||||
|  |     title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ") | ||||||
|  |     title_levels = {} | ||||||
|  |     new_lines = [] | ||||||
|  |     for line in lines: | ||||||
|  |         if ( | ||||||
|  |             len(new_lines) > 0 | ||||||
|  |             and len(line) >= len(new_lines[-1]) | ||||||
|  |             and len(set(line)) == 1 | ||||||
|  |             and line[0] in title_chars | ||||||
|  |             and line != "::" | ||||||
|  |         ): | ||||||
|  |             char = line[0] | ||||||
|  |             level = title_levels.get(char, len(title_levels) + 1) | ||||||
|  |             if level not in title_levels: | ||||||
|  |                 title_levels[char] = level | ||||||
|  |             new_lines[-1] = f"{'#' * level} {new_lines[-1]}" | ||||||
|  |         else: | ||||||
|  |             new_lines.append(line) | ||||||
|  |     return new_lines | ||||||
|  |  | ||||||
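
Underline styles map to heading levels in order of first appearance, for example:

    >>> process_titles(["Installation", "============", "", "Usage", "-----"])
    ['# Installation', '', '## Usage']
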
|  |  | ||||||
|  | # Matches rst table row-separator lines, e.g. `+----+----+`. | ||||||
|  | _re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$") | ||||||
|  | # Matches rst table row-separator lines preceded by an empty first column. | ||||||
|  | _re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$") | ||||||
|  | # Matches rst table header-separator lines, e.g. `+====+====+`. | ||||||
|  | _re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$") | ||||||
|  | # Re pattern that catches anchors of the form `.. _reference:`. | ||||||
|  | _re_anchor_section = re.compile(r"^\.\.\s+_(\S+):") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_pt_tf_code_blocks(text): | ||||||
|  |     """ | ||||||
|  |     Split PyTorch- and TensorFlow-specific code blocks. | ||||||
|  |     """ | ||||||
|  |     lines = text.split("\n") | ||||||
|  |     new_lines = [] | ||||||
|  |     idx = 0 | ||||||
|  |     while idx < len(lines): | ||||||
|  |         if lines[idx].startswith("```"): | ||||||
|  |             code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []} | ||||||
|  |             is_pytorch = False | ||||||
|  |             is_tensorflow = False | ||||||
|  |             idx += 1 | ||||||
|  |             while idx < len(lines) and lines[idx].strip() != "```": | ||||||
|  |                 if "## PYTORCH CODE" in lines[idx]: | ||||||
|  |                     is_pytorch = True | ||||||
|  |                     is_tensorflow = False | ||||||
|  |                 elif "## TENSORFLOW CODE" in lines[idx]: | ||||||
|  |                     is_tensorflow = True | ||||||
|  |                     is_pytorch = False | ||||||
|  |                 elif is_pytorch: | ||||||
|  |                     code_lines["pytorch"].append(lines[idx]) | ||||||
|  |                 elif is_tensorflow: | ||||||
|  |                     code_lines["tensorflow"].append(lines[idx]) | ||||||
|  |                 else: | ||||||
|  |                     code_lines["common"].append(lines[idx]) | ||||||
|  |                 idx += 1 | ||||||
|  |             if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0: | ||||||
|  |                 block_lines = ["<frameworkcontent>", "<pt>"] | ||||||
|  |                 block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"]) | ||||||
|  |                 block_lines.extend(["```", "</pt>", "<tf>"]) | ||||||
|  |                 block_lines.extend( | ||||||
|  |                     code_lines["common"].copy() + code_lines["tensorflow"] | ||||||
|  |                 ) | ||||||
|  |                 block_lines.extend(["```", "</tf>", "</frameworkcontent>"]) | ||||||
|  |                 new_lines.extend(block_lines) | ||||||
|  |             else: | ||||||
|  |                 block_lines = code_lines["common"] + ["```"] | ||||||
|  |                 new_lines.extend(block_lines) | ||||||
|  |             idx += 1 | ||||||
|  |         else: | ||||||
|  |             new_lines.append(lines[idx]) | ||||||
|  |             idx += 1 | ||||||
|  |     return "\n".join(new_lines) | ||||||
|  |  | ||||||
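
A sketch of the split on a mixed code block (content illustrative):

    text = "```py\n## PYTORCH CODE\nimport torch\n## TENSORFLOW CODE\nimport tensorflow\n```"
    print(split_pt_tf_code_blocks(text))
    # <frameworkcontent>
    # <pt>
    # ```py
    # import torch
    # ```
    # </pt>
    # <tf>
    # ```py
    # import tensorflow
    # ```
    # </tf>
    # </frameworkcontent>
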
|  |  | ||||||
|  | def convert_rst_to_mdx(rst_text, page_info, add_imports=True): | ||||||
|  |     """ | ||||||
|  |     Convert a document written in rst to mdx. | ||||||
|  |     """ | ||||||
|  |     lines = rst_text.split("\n") | ||||||
|  |     lines = process_titles(lines) | ||||||
|  |     if add_imports: | ||||||
|  |         new_lines = [ | ||||||
|  |             '<script lang="ts">', | ||||||
|  |             '	import Tip from "$lib/Tip.svelte";', | ||||||
|  |             '	import Youtube from "$lib/Youtube.svelte";', | ||||||
|  |             '	import Docstring from "$lib/Docstring.svelte";', | ||||||
|  |             '	import CodeBlock from "$lib/CodeBlock.svelte";', | ||||||
|  |             '	import CodeBlockFw from "$lib/CodeBlockFw.svelte";', | ||||||
|  |             '	import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";', | ||||||
|  |             '	import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";', | ||||||
|  |             '	import IconCopyLink from "$lib/IconCopyLink.svelte";', | ||||||
|  |             '	import FrameworkContent from "$lib/FrameworkContent.svelte";', | ||||||
|  |             '	import Markdown from "$lib/Markdown.svelte";', | ||||||
|  |             '	import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";', | ||||||
|  |             '	import Added from "$lib/Added.svelte";', | ||||||
|  |             '	import Changed from "$lib/Changed.svelte";', | ||||||
|  |             '	import Deprecated from "$lib/Deprecated.svelte";', | ||||||
|  |             '	import PipelineIcon from "$lib/PipelineIcon.svelte";', | ||||||
|  |             '	import PipelineTag from "$lib/PipelineTag.svelte";', | ||||||
|  |             "	", | ||||||
|  |             '	export let fw: "pt" | "tf"', | ||||||
|  |             "</script>", | ||||||
|  |             "<svelte:head>", | ||||||
|  |             '<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >', | ||||||
|  |             "</svelte:head>", | ||||||
|  |             "", | ||||||
|  |         ] | ||||||
|  |     else: | ||||||
|  |         new_lines = [] | ||||||
|  |     for line in lines: | ||||||
|  |         if _re_ignore_line_table.search(line) is not None: | ||||||
|  |             continue | ||||||
|  |         elif _re_ignore_line_table1.search(line) is not None: | ||||||
|  |             continue | ||||||
|  |         elif _re_sep_line_table.search(line) is not None: | ||||||
|  |             line = line.replace("=", "-").replace("+", "|") | ||||||
|  |         elif _re_anchor_section.search(line) is not None: | ||||||
|  |             anchor_name = _re_anchor_section.search(line).groups()[0] | ||||||
|  |             line = f"<a id='{anchor_name}'></a>" | ||||||
|  |         new_lines.append(line) | ||||||
|  |     text = "\n".join(new_lines) | ||||||
|  |  | ||||||
|  |     return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info)) | ||||||
src/kernels/_versions.py (new file, 52 lines)
| @ -0,0 +1,52 @@ | |||||||
|  | from typing import Dict, Optional | ||||||
|  |  | ||||||
|  | from huggingface_hub import HfApi | ||||||
|  | from huggingface_hub.hf_api import GitRefInfo | ||||||
|  | from packaging.specifiers import SpecifierSet | ||||||
|  | from packaging.version import InvalidVersion, Version | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: | ||||||
|  |     """Get kernel versions that are available in the repository.""" | ||||||
|  |     versions = {} | ||||||
|  |     for tag in HfApi().list_repo_refs(repo_id).tags: | ||||||
|  |         if not tag.name.startswith("v"): | ||||||
|  |             continue | ||||||
|  |         try: | ||||||
|  |             versions[Version(tag.name[1:])] = tag | ||||||
|  |         except InvalidVersion: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |     return versions | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo: | ||||||
|  |     """ | ||||||
|  |     Resolve a version specifier to the Git ref of the newest matching kernel version. | ||||||
|  |  | ||||||
|  |     The version specifier can be any valid Python version specifier: | ||||||
|  |     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers | ||||||
|  |     """ | ||||||
|  |     versions = _get_available_versions(repo_id) | ||||||
|  |     requirement = SpecifierSet(version_spec) | ||||||
|  |     accepted_versions = sorted(requirement.filter(versions.keys())) | ||||||
|  |  | ||||||
|  |     if len(accepted_versions) == 0: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"No version of `{repo_id}` satisfies requirement: {version_spec}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     return versions[accepted_versions[-1]] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def select_revision_or_version( | ||||||
|  |     repo_id: str, revision: Optional[str], version: Optional[str] | ||||||
|  | ) -> str: | ||||||
|  |     if revision is not None and version is not None: | ||||||
|  |         raise ValueError("Either a revision or a version must be specified, not both.") | ||||||
|  |     elif revision is None and version is None: | ||||||
|  |         revision = "main" | ||||||
|  |     elif version is not None: | ||||||
|  |         revision = resolve_version_spec_as_ref(repo_id, version).target_commit | ||||||
|  |     assert revision is not None | ||||||
|  |     return revision | ||||||
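
A minimal usage sketch (the repo id is illustrative, and resolution requires access to the Hub):

    ref = resolve_version_spec_as_ref("kernels-community/activation", ">=0.0.1")
    print(ref.name, ref.target_commit)  # e.g. a tag like v0.1.0 and its commit SHA

    # Or resolve a revision/version pair the way the loaders do:
    revision = select_revision_or_version(
        "kernels-community/activation", revision=None, version=">=0.0.1"
    )
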
src/kernels/check.py (new file, 142 lines)
| @ -0,0 +1,142 @@ | |||||||
|  | import sys | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from huggingface_hub import snapshot_download | ||||||
|  | from kernel_abi_check import ( | ||||||
|  |     BinaryFormat, | ||||||
|  |     IncompatibleAbi3Symbol, | ||||||
|  |     IncompatibleMacOSVersion, | ||||||
|  |     IncompatibleManylinuxSymbol, | ||||||
|  |     MissingMacOSVersion, | ||||||
|  |     NonAbi3Symbol, | ||||||
|  |     ObjectFile, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | from kernels.utils import CACHE_DIR | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_kernel( | ||||||
|  |     *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str | ||||||
|  | ): | ||||||
|  |     variants_path = ( | ||||||
|  |         Path( | ||||||
|  |             snapshot_download( | ||||||
|  |                 repo_id, | ||||||
|  |                 allow_patterns=["build/*"], | ||||||
|  |                 cache_dir=CACHE_DIR, | ||||||
|  |                 revision=revision, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         / "build" | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     has_issues = False | ||||||
|  |     for variant_path in variants_path.iterdir(): | ||||||
|  |         if not variant_path.is_dir(): | ||||||
|  |             print( | ||||||
|  |                 f"⛔ `build/` must only contain directories, found: {variant_path.name}", | ||||||
|  |                 file=sys.stderr, | ||||||
|  |             ) | ||||||
|  |             has_issues = True | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         print(f"Checking variant: {variant_path.name}", file=sys.stderr) | ||||||
|  |  | ||||||
|  |         indent = 2 | ||||||
|  |  | ||||||
|  |         for dylib_path in variant_path.rglob("*.so"): | ||||||
|  |             print_with_indent( | ||||||
|  |                 indent, | ||||||
|  |                 f"Dynamic library {dylib_path.relative_to(variant_path)}:", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |             o = ObjectFile(dylib_path) | ||||||
|  |             has_issues |= check_abi3(o, python_abi, indent + 2) | ||||||
|  |  | ||||||
|  |             # TODO: also check operating system | ||||||
|  |             if o.format() == BinaryFormat.ELF: | ||||||
|  |                 has_issues |= check_manylinux(o, manylinux, indent + 2) | ||||||
|  |             elif o.format() == BinaryFormat.MACH_O: | ||||||
|  |                 has_issues |= check_macos(o, macos, indent + 2) | ||||||
|  |  | ||||||
|  |     if has_issues: | ||||||
|  |         sys.exit(1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool: | ||||||
|  |     has_issues = False | ||||||
|  |     violations = object_file.check_python_abi(python_abi) | ||||||
|  |     if violations != []: | ||||||
|  |         has_issues = True | ||||||
|  |         print_with_indent( | ||||||
|  |             indent, | ||||||
|  |             f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:", | ||||||
|  |         ) | ||||||
|  |         for violation in violations: | ||||||
|  |             if isinstance(violation, IncompatibleAbi3Symbol): | ||||||
|  |                 print_with_indent( | ||||||
|  |                     indent + 3, | ||||||
|  |                     f"{violation.name}: {violation.version_added}", | ||||||
|  |                 ) | ||||||
|  |             elif isinstance(violation, NonAbi3Symbol): | ||||||
|  |                 print_with_indent( | ||||||
|  |                     indent + 3, | ||||||
|  |                     f"{violation.name}", | ||||||
|  |                 ) | ||||||
|  |     else: | ||||||
|  |         print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible") | ||||||
|  |  | ||||||
|  |     return has_issues | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool: | ||||||
|  |     has_issues = False | ||||||
|  |     violations = object_file.check_macos(macos) | ||||||
|  |     if violations != []: | ||||||
|  |         has_issues = True | ||||||
|  |         print_with_indent( | ||||||
|  |             indent, | ||||||
|  |             f"⛔ Found incompatibility with macOS {macos}:", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for violation in violations: | ||||||
|  |             if isinstance(violation, MissingMacOSVersion): | ||||||
|  |                 print_with_indent( | ||||||
|  |                     indent + 3, | ||||||
|  |                     "shared library does not contain macOS version", | ||||||
|  |                 ) | ||||||
|  |             elif isinstance(violation, IncompatibleMacOSVersion): | ||||||
|  |                 print_with_indent( | ||||||
|  |                     indent + 3, | ||||||
|  |                     f"shared library requires macOS {violation.version}", | ||||||
|  |                 ) | ||||||
|  |     else: | ||||||
|  |         print_with_indent(indent, f"🍏 compatible with macOS {macos}") | ||||||
|  |  | ||||||
|  |     return has_issues | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool: | ||||||
|  |     has_issues = False | ||||||
|  |     violations = object_file.check_manylinux(manylinux) | ||||||
|  |     if violations != []: | ||||||
|  |         has_issues = True | ||||||
|  |         print_with_indent( | ||||||
|  |             indent, | ||||||
|  |             f"⛔ Found symbols that are incompatible with {manylinux}:", | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         for violation in violations: | ||||||
|  |             if isinstance(violation, IncompatibleManylinuxSymbol): | ||||||
|  |                 print_with_indent( | ||||||
|  |                     indent + 3, | ||||||
|  |                     f"{violation.name}_{violation.dep}: {violation.version}", | ||||||
|  |                 ) | ||||||
|  |     else: | ||||||
|  |         print_with_indent(indent, f"🐧 {manylinux} compatible") | ||||||
|  |  | ||||||
|  |     return has_issues | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def print_with_indent(indent: int, message: str): | ||||||
|  |     print(f"{' ' * indent}{message}", file=sys.stderr) | ||||||
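
A minimal usage sketch mirroring the CLI defaults of `kernels check <repo_id>` (the repo id is illustrative; the process exits with status 1 when violations are found):

    check_kernel(
        macos="15.0",
        manylinux="manylinux_2_28",
        python_abi="3.9",
        repo_id="kernels-community/activation",
        revision="main",
    )
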
src/kernels/cli.py (modified)
| @ -4,10 +4,15 @@ import json | |||||||
| import sys | import sys | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | from huggingface_hub import create_repo, upload_folder | ||||||
|  |  | ||||||
| from kernels.compat import tomllib | from kernels.compat import tomllib | ||||||
| from kernels.lockfile import KernelLock, get_kernel_locks | from kernels.lockfile import KernelLock, get_kernel_locks | ||||||
| from kernels.utils import install_kernel, install_kernel_all_variants | from kernels.utils import install_kernel, install_kernel_all_variants | ||||||
|  |  | ||||||
|  | from .doc import generate_readme_for_kernel | ||||||
|  | from .wheel import build_variant_to_wheel | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     parser = argparse.ArgumentParser( |     parser = argparse.ArgumentParser( | ||||||
| @ -15,6 +20,31 @@ def main(): | |||||||
|     ) |     ) | ||||||
|     subparsers = parser.add_subparsers(required=True) |     subparsers = parser.add_subparsers(required=True) | ||||||
|  |  | ||||||
|  |     check_parser = subparsers.add_parser("check", help="Check a kernel for compliance") | ||||||
|  |     check_parser.add_argument("repo_id", type=str, help="The kernel repo ID") | ||||||
|  |     check_parser.add_argument( | ||||||
|  |         "--revision", | ||||||
|  |         type=str, | ||||||
|  |         default="main", | ||||||
|  |         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", | ||||||
|  |     ) | ||||||
|  |     check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0") | ||||||
|  |     check_parser.add_argument( | ||||||
|  |         "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28" | ||||||
|  |     ) | ||||||
|  |     check_parser.add_argument( | ||||||
|  |         "--python-abi", type=str, help="Python ABI version", default="3.9" | ||||||
|  |     ) | ||||||
|  |     check_parser.set_defaults( | ||||||
|  |         func=lambda args: check_kernel( | ||||||
|  |             macos=args.macos, | ||||||
|  |             manylinux=args.manylinux, | ||||||
|  |             python_abi=args.python_abi, | ||||||
|  |             repo_id=args.repo_id, | ||||||
|  |             revision=args.revision, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     download_parser = subparsers.add_parser("download", help="Download locked kernels") |     download_parser = subparsers.add_parser("download", help="Download locked kernels") | ||||||
|     download_parser.add_argument( |     download_parser.add_argument( | ||||||
|         "project_dir", |         "project_dir", | ||||||
| @ -28,6 +58,29 @@ def main(): | |||||||
|     ) |     ) | ||||||
|     download_parser.set_defaults(func=download_kernels) |     download_parser.set_defaults(func=download_kernels) | ||||||
|  |  | ||||||
|  |     upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub") | ||||||
|  |     upload_parser.add_argument( | ||||||
|  |         "kernel_dir", | ||||||
|  |         type=Path, | ||||||
|  |         help="Directory of the kernel build", | ||||||
|  |     ) | ||||||
|  |     upload_parser.add_argument( | ||||||
|  |         "--repo_id", | ||||||
|  |         type=str, | ||||||
|  |         help="Repository ID to upload to on the Hugging Face Hub", | ||||||
|  |     ) | ||||||
|  |     upload_parser.add_argument( | ||||||
|  |         "--branch", | ||||||
|  |         type=str, | ||||||
|  |         help="If set, the upload will be made to a particular branch of the provided `repo_id`.", | ||||||
|  |     ) | ||||||
|  |     upload_parser.add_argument( | ||||||
|  |         "--private", | ||||||
|  |         action="store_true", | ||||||
|  |         help="If the repository should be private.", | ||||||
|  |     ) | ||||||
|  |     upload_parser.set_defaults(func=upload_kernels) | ||||||
|  |  | ||||||
|     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") |     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") | ||||||
|     lock_parser.add_argument( |     lock_parser.add_argument( | ||||||
|         "project_dir", |         "project_dir", | ||||||
| @ -36,6 +89,47 @@ def main(): | |||||||
|     ) |     ) | ||||||
|     lock_parser.set_defaults(func=lock_kernels) |     lock_parser.set_defaults(func=lock_kernels) | ||||||
|  |  | ||||||
|  |     to_wheel_parser = subparsers.add_parser( | ||||||
|  |         "to-wheel", help="Convert a kernel to a wheel file" | ||||||
|  |     ) | ||||||
|  |     to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID") | ||||||
|  |     to_wheel_parser.add_argument("version", type=str, help="The kernel version") | ||||||
|  |     to_wheel_parser.add_argument( | ||||||
|  |         "--python-version", | ||||||
|  |         type=str, | ||||||
|  |         default="3.9", | ||||||
|  |         help="The minimum Python version. Must match the Python version that the kernel was compiled for.", | ||||||
|  |     ) | ||||||
|  |     to_wheel_parser.add_argument( | ||||||
|  |         "--manylinux-version", | ||||||
|  |         type=str, | ||||||
|  |         default="2.28", | ||||||
|  |         help="The manylinux version. Must match the manylinux version that the kernel was compiled for.", | ||||||
|  |     ) | ||||||
|  |     to_wheel_parser.set_defaults(func=kernels_to_wheel) | ||||||
|  |  | ||||||
|  |     # Add generate-readme subcommand parser | ||||||
|  |     generate_readme_parser = subparsers.add_parser( | ||||||
|  |         "generate-readme", | ||||||
|  |         help="Generate README snippets for a kernel's public functions", | ||||||
|  |     ) | ||||||
|  |     generate_readme_parser.add_argument( | ||||||
|  |         "repo_id", | ||||||
|  |         type=str, | ||||||
|  |         help="The kernel repo ID (e.g., kernels-community/activation)", | ||||||
|  |     ) | ||||||
|  |     generate_readme_parser.add_argument( | ||||||
|  |         "--revision", | ||||||
|  |         type=str, | ||||||
|  |         default="main", | ||||||
|  |         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", | ||||||
|  |     ) | ||||||
|  |     generate_readme_parser.set_defaults( | ||||||
|  |         func=lambda args: generate_readme_for_kernel( | ||||||
|  |             repo_id=args.repo_id, revision=args.revision | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     args.func(args) |     args.func(args) | ||||||
|  |  | ||||||
| @ -77,6 +171,24 @@ def download_kernels(args): | |||||||
|         sys.exit(1) |         sys.exit(1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def kernels_to_wheel(args): | ||||||
|  |     variants_path = install_kernel_all_variants( | ||||||
|  |         repo_id=args.repo_id, revision=f"v{args.version}" | ||||||
|  |     ) | ||||||
|  |     for variant_path in variants_path.iterdir(): | ||||||
|  |         if not variant_path.is_dir(): | ||||||
|  |             continue | ||||||
|  |         wheel_path = build_variant_to_wheel( | ||||||
|  |             manylinux_version=args.manylinux_version, | ||||||
|  |             python_version=args.python_version, | ||||||
|  |             repo_id=args.repo_id, | ||||||
|  |             version=args.version, | ||||||
|  |             variant_path=variant_path, | ||||||
|  |             wheel_dir=Path("."), | ||||||
|  |         ) | ||||||
|  |         print(f"☸️ {wheel_path.name}", file=sys.stderr) | ||||||
|  |  | ||||||
|  |  | ||||||
| def lock_kernels(args): | def lock_kernels(args): | ||||||
|     with open(args.project_dir / "pyproject.toml", "rb") as f: |     with open(args.project_dir / "pyproject.toml", "rb") as f: | ||||||
|         data = tomllib.load(f) |         data = tomllib.load(f) | ||||||
| @ -91,8 +203,57 @@ def lock_kernels(args): | |||||||
|         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) |         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def upload_kernels(args): | ||||||
|  |     kernel_dir = Path(args.kernel_dir).resolve() | ||||||
|  |     build_dir = kernel_dir / "build" | ||||||
|  |     if not kernel_dir.is_dir(): | ||||||
|  |         raise ValueError(f"{kernel_dir} is not a directory") | ||||||
|  |     if not build_dir.is_dir(): | ||||||
|  |         raise ValueError("Couldn't find `build` directory inside `kernel_dir`") | ||||||
|  |  | ||||||
|  |     repo_id = create_repo( | ||||||
|  |         repo_id=args.repo_id, private=args.private, exist_ok=True | ||||||
|  |     ).repo_id | ||||||
|  |  | ||||||
|  |     delete_patterns: set[str] = set() | ||||||
|  |     for build_variant in build_dir.iterdir(): | ||||||
|  |         if build_variant.is_dir(): | ||||||
|  |             delete_patterns.add(f"{build_variant.name}/**") | ||||||
|  |  | ||||||
|  |     upload_folder( | ||||||
|  |         repo_id=repo_id, | ||||||
|  |         folder_path=build_dir, | ||||||
|  |         revision=args.branch, | ||||||
|  |         path_in_repo="build", | ||||||
|  |         delete_patterns=list(delete_patterns), | ||||||
|  |         commit_message="Build uploaded using `kernels`.", | ||||||
|  |     ) | ||||||
|  |     print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.") | ||||||
|  |  | ||||||
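
A minimal invocation sketch, equivalent to `kernels upload ./kernel-build --repo_id user/my-kernel` (path and repo id illustrative; requires Hub credentials and a `build/` directory):

    from argparse import Namespace

    upload_kernels(
        Namespace(
            kernel_dir="./kernel-build",
            repo_id="user/my-kernel",
            branch=None,
            private=False,
        )
    )
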
|  |  | ||||||
| class _JSONEncoder(json.JSONEncoder): | class _JSONEncoder(json.JSONEncoder): | ||||||
|     def default(self, o): |     def default(self, o): | ||||||
|         if dataclasses.is_dataclass(o): |         if dataclasses.is_dataclass(o): | ||||||
|             return dataclasses.asdict(o) |             return dataclasses.asdict(o) | ||||||
|         return super().default(o) |         return super().default(o) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def check_kernel( | ||||||
|  |     *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str | ||||||
|  | ): | ||||||
|  |     try: | ||||||
|  |         import kernels.check | ||||||
|  |     except ImportError: | ||||||
|  |         print( | ||||||
|  |             "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check", | ||||||
|  |             file=sys.stderr, | ||||||
|  |         ) | ||||||
|  |         sys.exit(1) | ||||||
|  |  | ||||||
|  |     kernels.check.check_kernel( | ||||||
|  |         macos=macos, | ||||||
|  |         manylinux=manylinux, | ||||||
|  |         python_abi=python_abi, | ||||||
|  |         repo_id=repo_id, | ||||||
|  |         revision=revision, | ||||||
|  |     ) | ||||||
|  | |||||||
src/kernels/doc.py (new file, 242 lines)
| @ -0,0 +1,242 @@ | |||||||
|  | import inspect | ||||||
|  | import re | ||||||
|  | import sys | ||||||
|  | from types import ModuleType | ||||||
|  |  | ||||||
|  | import yaml | ||||||
|  |  | ||||||
|  | from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx | ||||||
|  | from .utils import get_kernel | ||||||
|  |  | ||||||
|  | _RE_PARAMETERS = re.compile( | ||||||
|  |     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL | ||||||
|  | ) | ||||||
|  | _RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL) | ||||||
|  | _RE_RETURNTYPE = re.compile( | ||||||
|  |     r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _extract_description_before_tags(docstring_mdx: str) -> str: | ||||||
|  |     """Extract the description part of a docstring before any tags.""" | ||||||
|  |     params_pos = docstring_mdx.find("<parameters>") | ||||||
|  |     returns_pos = docstring_mdx.find("<returns>") | ||||||
|  |     returntype_pos = docstring_mdx.find("<returntype>") | ||||||
|  |     positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1] | ||||||
|  |  | ||||||
|  |     if positions: | ||||||
|  |         first_tag_pos = min(positions) | ||||||
|  |         return docstring_mdx[:first_tag_pos].strip() | ||||||
|  |     else: | ||||||
|  |         return docstring_mdx.strip() | ||||||
|  |  | ||||||
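
For instance (docstring illustrative):

    >>> _extract_description_before_tags("Adds two tensors.\n<parameters>...</parameters>")
    'Adds two tensors.'
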
|  |  | ||||||
|  | def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None: | ||||||
|  |     """Print the parameters section from a docstring.""" | ||||||
|  |     matches = _RE_PARAMETERS.findall(docstring_mdx) | ||||||
|  |     if matches: | ||||||
|  |         header = "#" * header_level | ||||||
|  |         print(f"\n{header} Parameters") | ||||||
|  |         for match in matches: | ||||||
|  |             print(f"\n{match[0].strip()}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _print_returns_section( | ||||||
|  |     docstring_mdx: str, *, context_name: str, header_level: int | ||||||
|  | ) -> None: | ||||||
|  |     """Print the returns section from a docstring.""" | ||||||
|  |     return_matches = _RE_RETURNS.findall(docstring_mdx) | ||||||
|  |     returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx) | ||||||
|  |  | ||||||
|  |     if return_matches or returntype_matches: | ||||||
|  |         header = "#" * header_level | ||||||
|  |         print(f"\n{header} Returns") | ||||||
|  |  | ||||||
|  |         if returntype_matches: | ||||||
|  |             if len(returntype_matches) > 1: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     f"More than one <returntype> tag found in docstring for {context_name}" | ||||||
|  |                 ) | ||||||
|  |             print(f"\n**Type**: {returntype_matches[0][0].strip()}") | ||||||
|  |  | ||||||
|  |         if return_matches: | ||||||
|  |             for match in return_matches: | ||||||
|  |                 print(f"\n{match[0].strip()}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_docstring(obj, use_dict_check: bool = False) -> str: | ||||||
|  |     """Get docstring from an object, with fallback to default message.""" | ||||||
|  |     # Check whether the class/method itself has docs and not just | ||||||
|  |     # the superclass. | ||||||
|  |     if use_dict_check: | ||||||
|  |         has_doc = obj.__dict__.get("__doc__", None) is not None | ||||||
|  |     else: | ||||||
|  |         has_doc = getattr(obj, "__doc__", None) is not None | ||||||
|  |  | ||||||
|  |     # We use inspect.getdoc because it does normalization. | ||||||
|  |     doc = inspect.getdoc(obj) | ||||||
|  |  | ||||||
|  |     return doc if has_doc and doc is not None else "No documentation available." | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _process_and_print_docstring( | ||||||
|  |     docstring: str, *, kernel_name: str, context_name: str, header_level: int | ||||||
|  | ) -> None: | ||||||
|  |     """Convert docstring to MDX and print description, parameters, and returns sections.""" | ||||||
|  |     docstring_mdx = convert_rst_docstring_to_mdx( | ||||||
|  |         docstring, page_info={"package_name": kernel_name} | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Print the description | ||||||
|  |     description = _extract_description_before_tags(docstring_mdx) | ||||||
|  |     print(f"\n{description}") | ||||||
|  |  | ||||||
|  |     # Print parameters and returns sections | ||||||
|  |     _print_parameters_section(docstring_mdx, header_level=header_level) | ||||||
|  |     _print_returns_section( | ||||||
|  |         docstring_mdx, context_name=context_name, header_level=header_level | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None: | ||||||
|  |     kernel_module = get_kernel(repo_id=repo_id, revision=revision) | ||||||
|  |     kernel_name = repo_id.split("/")[-1].replace("-", "_") | ||||||
|  |  | ||||||
|  |     generate_metadata(kernel_module) | ||||||
|  |     generate_kernel_doc(kernel_module, kernel_name) | ||||||
|  |     generate_function_doc(kernel_module, kernel_name) | ||||||
|  |     generate_layers_doc(kernel_module, kernel_name) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate_metadata(module: ModuleType) -> None: | ||||||
|  |     metadata = getattr(module, "__kernel_metadata__", {}) | ||||||
|  |     if "tags" not in metadata: | ||||||
|  |         metadata["tags"] = ["kernel"] | ||||||
|  |     else: | ||||||
|  |         if "kernel" not in metadata["tags"]: | ||||||
|  |             metadata["tags"].append("kernel") | ||||||
|  |  | ||||||
|  |     print("---") | ||||||
|  |     print(yaml.dump(metadata), end="") | ||||||
|  |     print("---") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None: | ||||||
|  |     docstring = module.__doc__.strip() if module.__doc__ is not None else None | ||||||
|  |     if docstring: | ||||||
|  |         title, rest = docstring.split("\n", 1) | ||||||
|  |         print(f"# {title.strip()}") | ||||||
|  |         print( | ||||||
|  |             f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||||
|  |     print("\n## Functions") | ||||||
|  |  | ||||||
|  |     # Track if we found any functions | ||||||
|  |     found_functions = False | ||||||
|  |  | ||||||
|  |     for name, func in inspect.getmembers(kernel_module, inspect.isfunction): | ||||||
|  |         # Do not include imported functions. | ||||||
|  |         if func.__module__ != kernel_module.__name__: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         # Exclude private functions. | ||||||
|  |         if name.startswith("_"): | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         found_functions = True | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             sig = inspect.signature(func) | ||||||
|  |             docstring = _get_docstring(func) | ||||||
|  |         except ValueError: | ||||||
|  |             print( | ||||||
|  |                 f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}", | ||||||
|  |                 file=sys.stderr, | ||||||
|  |             ) | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         print(f"\n### Function `{name}`") | ||||||
|  |         print(f"\n`{sig}`") | ||||||
|  |  | ||||||
|  |         _process_and_print_docstring( | ||||||
|  |             docstring, kernel_name=kernel_name, context_name=name, header_level=3 | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     if not found_functions: | ||||||
|  |         print("\nNo public top-level functions.") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None: | ||||||
|  |     # Check if layers module is available | ||||||
|  |     layers_module = getattr(kernel_module, "layers", None) | ||||||
|  |     if layers_module is None: | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     print("\n## Layers") | ||||||
|  |  | ||||||
|  |     # Track if we found any classes | ||||||
|  |     found_classes = False | ||||||
|  |  | ||||||
|  |     for class_name, cls in inspect.getmembers(layers_module, inspect.isclass): | ||||||
|  |         # Exclude classes that were imported. | ||||||
|  |         if cls.__module__ != layers_module.__name__: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         found_classes = True | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             # Get docstring, but not from superclasses. | ||||||
|  |             class_docstring = _get_docstring(cls, use_dict_check=True) | ||||||
|  |         except Exception: | ||||||
|  |             print( | ||||||
|  |                 f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}", | ||||||
|  |                 file=sys.stderr, | ||||||
|  |             ) | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |         print(f"\n### Class `{class_name}`") | ||||||
|  |  | ||||||
|  |         # Always print class description (helper handles conversion and formatting) | ||||||
|  |         class_docstring_mdx = convert_rst_docstring_to_mdx( | ||||||
|  |             class_docstring, page_info={"package_name": kernel_name} | ||||||
|  |         ) | ||||||
|  |         description = _extract_description_before_tags(class_docstring_mdx) | ||||||
|  |         print(f"\n{description}") | ||||||
|  |  | ||||||
|  |         # Document methods | ||||||
|  |         print("\n#### Methods") | ||||||
|  |  | ||||||
|  |         for method_name, method in inspect.getmembers(cls, inspect.isfunction): | ||||||
|  |             # Note: also skip __init__, since extension layers cannot have a constructor. | ||||||
|  |             if method_name.startswith("_"): | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # Skip methods from superclasses. | ||||||
|  |             if method_name not in cls.__dict__: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             try: | ||||||
|  |                 sig = inspect.signature(method) | ||||||
|  |                 method_docstring = _get_docstring(method) | ||||||
|  |             except ValueError: | ||||||
|  |                 print( | ||||||
|  |                     f"Warning: Could not retrieve signature for {method_name} in {class_name}", | ||||||
|  |                     file=sys.stderr, | ||||||
|  |                 ) | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             print(f"\n##### Method `{method_name}`") | ||||||
|  |             print(f"\n`{sig}`") | ||||||
|  |  | ||||||
|  |             _process_and_print_docstring( | ||||||
|  |                 method_docstring, | ||||||
|  |                 kernel_name=kernel_name, | ||||||
|  |                 context_name=method_name, | ||||||
|  |                 header_level=6, | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     if not found_classes: | ||||||
|  |         print("\nNo layers defined.") | ||||||
							
								
								
									
src/kernels/layer.py (1205 lines; file diff suppressed because it is too large)
src/kernels/lockfile.py
| @ -4,10 +4,8 @@ from pathlib import Path |
| from typing import Dict, List, Tuple | from typing import Dict, List, Tuple | ||||||
|  |  | ||||||
| from huggingface_hub import HfApi | from huggingface_hub import HfApi | ||||||
| from huggingface_hub.hf_api import GitRefInfo |  | ||||||
| from packaging.specifiers import SpecifierSet |  | ||||||
| from packaging.version import InvalidVersion, Version |  | ||||||
|  |  | ||||||
|  | from kernels._versions import resolve_version_spec_as_ref | ||||||
| from kernels.compat import tomllib | from kernels.compat import tomllib | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -31,20 +29,6 @@ class KernelLock: | |||||||
|         return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) |         return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: |  | ||||||
|     """Get kernel versions that are available in the repository.""" |  | ||||||
|     versions = {} |  | ||||||
|     for tag in HfApi().list_repo_refs(repo_id).tags: |  | ||||||
|         if not tag.name.startswith("v"): |  | ||||||
|             continue |  | ||||||
|         try: |  | ||||||
|             versions[Version(tag.name[1:])] = tag |  | ||||||
|         except InvalidVersion: |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|     return versions |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | ||||||
|     """ |     """ | ||||||
|     Get the locks for a kernel with the given version spec. |     Get the locks for a kernel with the given version spec. | ||||||
| @ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | |||||||
|     The version specifier can be any valid Python version specifier: |     The version specifier can be any valid Python version specifier: | ||||||
|     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers |     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers | ||||||
|     """ |     """ | ||||||
|     versions = _get_available_versions(repo_id) |     tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec) | ||||||
|     requirement = SpecifierSet(version_spec) |  | ||||||
|     accepted_versions = sorted(requirement.filter(versions.keys())) |  | ||||||
|  |  | ||||||
|     if len(accepted_versions) == 0: |  | ||||||
|         raise ValueError( |  | ||||||
|             f"No version of `{repo_id}` satisfies requirement: {version_spec}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     tag_for_newest = versions[accepted_versions[-1]] |  | ||||||
|  |  | ||||||
|     r = HfApi().repo_info( |     r = HfApi().repo_info( | ||||||
|         repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True |         repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True | ||||||
|  | |||||||
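A hedged example of the refactored locking path: assuming `resolve_version_spec_as_ref` picks the newest tag matching the specifier, `get_kernel_locks` pins that tag's commit. Repo id and version spec below are illustrative:

```python
# Sketch: resolve a version spec and pin the matching commit in a lock.
from kernels.lockfile import get_kernel_locks

lock = get_kernel_locks("kernels-community/activation", ">=0.0.1")
print(lock.repo_id, lock.sha)  # commit SHA that the lockfile records
```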
src/kernels/utils.py
| @ -4,6 +4,7 @@ import importlib |
| import importlib.metadata | import importlib.metadata | ||||||
| import inspect | import inspect | ||||||
| import json | import json | ||||||
|  | import logging | ||||||
| import os | import os | ||||||
| import platform | import platform | ||||||
| import sys | import sys | ||||||
| @ -12,29 +13,71 @@ from pathlib import Path | |||||||
| from types import ModuleType | from types import ModuleType | ||||||
| from typing import Dict, List, Optional, Tuple | from typing import Dict, List, Optional, Tuple | ||||||
|  |  | ||||||
| from huggingface_hub import snapshot_download | from huggingface_hub import file_exists, snapshot_download | ||||||
| from packaging.version import parse | from packaging.version import parse | ||||||
|  |  | ||||||
|  | from kernels._versions import select_revision_or_version | ||||||
| from kernels.lockfile import KernelLock, VariantLock | from kernels.lockfile import KernelLock, VariantLock | ||||||
|  |  | ||||||
| CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None) |  | ||||||
|  | def _get_cache_dir() -> Optional[str]: | ||||||
|  |     """Returns the kernels cache directory.""" | ||||||
|  |     cache_dir = os.environ.get("HF_KERNELS_CACHE", None) | ||||||
|  |     if cache_dir is not None: | ||||||
|  |         logging.warning( | ||||||
|  |             "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead" | ||||||
|  |         ) | ||||||
|  |         return cache_dir | ||||||
|  |  | ||||||
|  |     return os.environ.get("KERNELS_CACHE", None) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | CACHE_DIR: Optional[str] = _get_cache_dir() | ||||||
|  |  | ||||||
|  |  | ||||||
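An illustrative override of the cache location (the directory below is hypothetical). Note that `CACHE_DIR` is resolved at import time, and the deprecated variable still wins when both are set:

```python
# Set before importing kernels, since CACHE_DIR is read at module import.
import os

os.environ["KERNELS_CACHE"] = "/data/kernels-cache"  # hypothetical path
```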
|  | def _get_privateuse_backend_name() -> Optional[str]: | ||||||
|  |     import torch | ||||||
|  |  | ||||||
|  |     if hasattr(torch._C, "_get_privateuse1_backend_name"): | ||||||
|  |         return torch._C._get_privateuse1_backend_name() | ||||||
|  |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_variant() -> str: | def build_variant() -> str: | ||||||
|     import torch |     import torch | ||||||
|  |  | ||||||
|     if torch.version.cuda is None: |     if torch.version.cuda is not None: | ||||||
|  |         cuda_version = parse(torch.version.cuda) | ||||||
|  |         compute_framework = f"cu{cuda_version.major}{cuda_version.minor}" | ||||||
|  |     elif torch.version.hip is not None: | ||||||
|  |         rocm_version = parse(torch.version.hip.split("-")[0]) | ||||||
|  |         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" | ||||||
|  |     elif torch.backends.mps.is_available(): | ||||||
|  |         compute_framework = "metal" | ||||||
|  |     elif torch.version.xpu is not None: | ||||||
|  |         version = torch.version.xpu | ||||||
|  |         compute_framework = f"xpu{version[0:4]}{version[5:6]}" | ||||||
|  |     elif _get_privateuse_backend_name() == "npu": | ||||||
|  |         from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found] | ||||||
|  |  | ||||||
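|  |         # get_cann_version() is assumed to return a version string such as | ||||||
|  |         # "8.0..."; characters 0 and 2 hold the major and minor digits. | ||||||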
|  |         cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] | ||||||
|  |         compute_framework = f"cann{cann_major}{cann_minor}" | ||||||
|  |     else: | ||||||
|         raise AssertionError( |         raise AssertionError( | ||||||
|             "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled." |             "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     torch_version = parse(torch.__version__) |     torch_version = parse(torch.__version__) | ||||||
|     cuda_version = parse(torch.version.cuda) |  | ||||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" |  | ||||||
|     cpu = platform.machine() |     cpu = platform.machine() | ||||||
|     os = platform.system().lower() |     os = platform.system().lower() | ||||||
|  |  | ||||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}" |     if os == "darwin": | ||||||
|  |         cpu = "aarch64" if cpu == "arm64" else cpu | ||||||
|  |         return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" | ||||||
|  |  | ||||||
|  |     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||||
|  |  | ||||||
|  |     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}" | ||||||
|  |  | ||||||
|  |  | ||||||
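For illustration, variant strings the updated `build_variant()` can produce; exact values depend on the local Torch build, and these examples are assumptions rather than pinned outputs:

```python
# Hypothetical outputs of build_variant():
#   Linux + CUDA 12.6, Torch 2.6:  "torch26-cxx11-cu126-x86_64-linux"
#   Linux + ROCm 6.2,  Torch 2.6:  "torch26-cxx11-rocm62-x86_64-linux"
#   macOS + Metal,     Torch 2.6:  "torch26-metal-aarch64-darwin"  # no cxxabi tag
```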
| def universal_build_variant() -> str: | def universal_build_variant() -> str: | ||||||
| @ -69,7 +112,20 @@ def install_kernel( | |||||||
|     """ |     """ | ||||||
|     Download a kernel for the current environment to the cache. |     Download a kernel for the current environment to the cache. | ||||||
|  |  | ||||||
|     The output path is validated againt `hash` when set. |     The output path is validated against the hashes in `variant_locks` when provided. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         revision (`str`): | ||||||
|  |             The specific revision (branch, tag, or commit) to download. | ||||||
|  |         local_files_only (`bool`, *optional*, defaults to `False`): | ||||||
|  |             Whether to only use local files and not download from the Hub. | ||||||
|  |         variant_locks (`Dict[str, VariantLock]`, *optional*): | ||||||
|  |             Optional dictionary of variant locks for validation. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory. | ||||||
|     """ |     """ | ||||||
|     package_name = package_name_from_repo_id(repo_id) |     package_name = package_name_from_repo_id(repo_id) | ||||||
|     variant = build_variant() |     variant = build_variant() | ||||||
| @ -84,6 +140,23 @@ def install_kernel( | |||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     try: | ||||||
|  |         return _load_kernel_from_path(repo_path, package_name, variant_locks) | ||||||
|  |     except FileNotFoundError: | ||||||
|  |         # Re-raise with a more specific error message. | ||||||
|  |         raise FileNotFoundError( | ||||||
|  |             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _load_kernel_from_path( | ||||||
|  |     repo_path: Path, | ||||||
|  |     package_name: str, | ||||||
|  |     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||||
|  | ) -> Tuple[str, Path]: | ||||||
|  |     variant = build_variant() | ||||||
|  |     universal_variant = universal_build_variant() | ||||||
|  |  | ||||||
|     variant_path = repo_path / "build" / variant |     variant_path = repo_path / "build" / variant | ||||||
|     universal_variant_path = repo_path / "build" / universal_variant |     universal_variant_path = repo_path / "build" / universal_variant | ||||||
|  |  | ||||||
| @ -102,7 +175,7 @@ def install_kernel( | |||||||
|  |  | ||||||
|     if not os.path.exists(module_init_path): |     if not os.path.exists(module_init_path): | ||||||
|         raise FileNotFoundError( |         raise FileNotFoundError( | ||||||
|             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" |             f"Kernel at path `{repo_path}` does not have build: {variant}" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     return package_name, variant_path |     return package_name, variant_path | ||||||
| @ -139,17 +212,128 @@ def install_kernel_all_variants( | |||||||
|     return repo_path / "build" |     return repo_path / "build" | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | def get_kernel( | ||||||
|  |     repo_id: str, revision: Optional[str] = None, version: Optional[str] = None | ||||||
|  | ) -> ModuleType: | ||||||
|  |     """ | ||||||
|  |     Load a kernel from the kernel hub. | ||||||
|  |  | ||||||
|  |     This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before) | ||||||
|  |     and then loads the kernel. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         revision (`str`, *optional*, defaults to `"main"`): | ||||||
|  |             The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. | ||||||
|  |         version (`str`, *optional*): | ||||||
|  |             The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. | ||||||
|  |             Cannot be used together with `revision`. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module. | ||||||
|  |  | ||||||
|  |     Example: | ||||||
|  |         ```python | ||||||
|  |         import torch | ||||||
|  |         from kernels import get_kernel | ||||||
|  |  | ||||||
|  |         activation = get_kernel("kernels-community/activation") | ||||||
|  |         x = torch.randn(10, 20, device="cuda") | ||||||
|  |         out = torch.empty_like(x) | ||||||
|  |         result = activation.silu_and_mul(out, x) | ||||||
|  |         ``` | ||||||
|  |     """ | ||||||
|  |     revision = select_revision_or_version(repo_id, revision, version) | ||||||
|     package_name, package_path = install_kernel(repo_id, revision=revision) |     package_name, package_path = install_kernel(repo_id, revision=revision) | ||||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") |     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: | ||||||
|  |     """ | ||||||
|  |     Import a kernel from a local kernel repository path. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_path (`Path`): | ||||||
|  |             The local path to the kernel repository. | ||||||
|  |         package_name (`str`): | ||||||
|  |             The name of the package to import from the repository. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module. | ||||||
|  |     """ | ||||||
|  |     variant = build_variant() | ||||||
|  |     universal_variant = universal_build_variant() | ||||||
|  |  | ||||||
|  |     # Presume we were given the top level path of the kernel repository. | ||||||
|  |     for base_path in [repo_path, repo_path / "build"]: | ||||||
|  |         # Prefer the universal variant if it exists. | ||||||
|  |         for v in [universal_variant, variant]: | ||||||
|  |             package_path = base_path / v / package_name / "__init__.py" | ||||||
|  |             if package_path.exists(): | ||||||
|  |                 return import_from_path(package_name, package_path) | ||||||
|  |  | ||||||
|  |     # If we didn't find the package in the repo, we may have been | ||||||
|  |     # given an explicit package path. | ||||||
|  |     package_path = repo_path / package_name / "__init__.py" | ||||||
|  |     if package_path.exists(): | ||||||
|  |         return import_from_path(package_name, package_path) | ||||||
|  |  | ||||||
|  |     raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def has_kernel( | ||||||
|  |     repo_id: str, revision: Optional[str] = None, version: Optional[str] = None | ||||||
|  | ) -> bool: | ||||||
|  |     """ | ||||||
|  |     Check whether a kernel build exists for the current environment (Torch version and compute framework). | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         revision (`str`, *optional*, defaults to `"main"`): | ||||||
|  |             The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. | ||||||
|  |         version (`str`, *optional*): | ||||||
|  |             The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. | ||||||
|  |             Cannot be used together with `revision`. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `bool`: `True` if a kernel is available for the current environment. | ||||||
|  |     """ | ||||||
|  |     revision = select_revision_or_version(repo_id, revision, version) | ||||||
|  |  | ||||||
|  |     package_name = package_name_from_repo_id(repo_id) | ||||||
|  |     variant = build_variant() | ||||||
|  |     universal_variant = universal_build_variant() | ||||||
|  |  | ||||||
|  |     if file_exists( | ||||||
|  |         repo_id, | ||||||
|  |         revision=revision, | ||||||
|  |         filename=f"build/{universal_variant}/{package_name}/__init__.py", | ||||||
|  |     ): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     return file_exists( | ||||||
|  |         repo_id, | ||||||
|  |         revision=revision, | ||||||
|  |         filename=f"build/{variant}/{package_name}/__init__.py", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
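A small usage sketch combining the new helpers (the repository id is illustrative):

```python
# Check for a build matching the current environment before loading.
from kernels import get_kernel, has_kernel

if has_kernel("kernels-community/activation"):
    activation = get_kernel("kernels-community/activation")
```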
| def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||||
|     """ |     """ | ||||||
|     Get a pre-downloaded, locked kernel. |     Get a pre-downloaded, locked kernel. | ||||||
|  |  | ||||||
|     If `lockfile` is not specified, the lockfile will be loaded from the |     If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata. | ||||||
|     caller's package metadata. |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         lockfile (`Path`, *optional*): | ||||||
|  |             Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module. | ||||||
|     """ |     """ | ||||||
|     if lockfile is None: |     if lockfile is None: | ||||||
|         locked_sha = _get_caller_locked_kernel(repo_id) |         locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
| @ -194,7 +378,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | ||||||
|     """Get a kernel using a lock file.""" |     """ | ||||||
|  |     Get a kernel using a lock file. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         local_files_only (`bool`, *optional*, defaults to `False`): | ||||||
|  |             Whether to only use local files and not download from the Hub. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module. | ||||||
|  |     """ | ||||||
|     locked_sha = _get_caller_locked_kernel(repo_id) |     locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
|  |  | ||||||
|     if locked_sha is None: |     if locked_sha is None: | ||||||
|  | |||||||
							
								
								
									
src/kernels/wheel.py (new file, 186 lines)
| @ -0,0 +1,186 @@ |
|  | import email.policy | ||||||
|  | import os | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from email.message import Message | ||||||
|  | from importlib.metadata import PackageNotFoundError, version | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     KERNELS_VERSION = version("kernels") | ||||||
|  | except PackageNotFoundError: | ||||||
|  |     KERNELS_VERSION = "unknown" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class Metadata: | ||||||
|  |     name: str | ||||||
|  |     version: str | ||||||
|  |     cuda_version: Optional[str] | ||||||
|  |     cxx_abi_version: Optional[str] | ||||||
|  |     torch_version: Optional[str] | ||||||
|  |     os: Optional[str] | ||||||
|  |     platform: Optional[str] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def is_universal(self) -> bool: | ||||||
|  |         return self.platform is None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def build_variant_to_wheel( | ||||||
|  |     repo_id: str, | ||||||
|  |     *, | ||||||
|  |     version: str, | ||||||
|  |     variant_path: Path, | ||||||
|  |     wheel_dir: Path, | ||||||
|  |     manylinux_version: str = "2.28", | ||||||
|  |     python_version: str = "3.9", | ||||||
|  | ) -> Path: | ||||||
|  |     """ | ||||||
|  |     Create a wheel file from the variant path. | ||||||
|  |     """ | ||||||
|  |     name = repo_id.split("/")[-1].replace("_", "-") | ||||||
|  |     metadata = extract_metadata(name, version, variant_path) | ||||||
|  |     return build_wheel( | ||||||
|  |         metadata, | ||||||
|  |         variant_path=variant_path, | ||||||
|  |         wheel_dir=wheel_dir, | ||||||
|  |         manylinux_version=manylinux_version, | ||||||
|  |         python_version=python_version, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata: | ||||||
|  |     """ | ||||||
|  |     Extract metadata from the variant path. | ||||||
|  |     """ | ||||||
|  |     if variant_path.name == "torch-universal": | ||||||
|  |         return Metadata( | ||||||
|  |             name=name, | ||||||
|  |             version=version, | ||||||
|  |             cuda_version=None, | ||||||
|  |             cxx_abi_version=None, | ||||||
|  |             torch_version=None, | ||||||
|  |             os=None, | ||||||
|  |             platform=None, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     if not variant_path.name.startswith("torch"): | ||||||
|  |         raise ValueError("Currently only conversion of Torch kernels is supported.") | ||||||
|  |  | ||||||
|  |     variant_parts = variant_path.name.removeprefix("torch").split("-") | ||||||
|  |     if len(variant_parts) != 5: | ||||||
|  |         raise ValueError(f"Invalid variant name: {variant_path.name}") | ||||||
|  |  | ||||||
|  |     torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}" | ||||||
|  |     cpp_abi_version = variant_parts[1].removeprefix("cxx") | ||||||
|  |     cuda_version = variant_parts[2].removeprefix("cu") | ||||||
|  |     platform = variant_parts[3].replace("-", "_") | ||||||
|  |     os = variant_parts[4] | ||||||
|  |  | ||||||
|  |     return Metadata( | ||||||
|  |         name=name, | ||||||
|  |         version=version, | ||||||
|  |         cuda_version=cuda_version, | ||||||
|  |         cxx_abi_version=cpp_abi_version, | ||||||
|  |         torch_version=torch_version, | ||||||
|  |         os=os, | ||||||
|  |         platform=platform, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
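To make the variant-name parsing concrete, a sketch of `extract_metadata` on a CUDA build directory (the path and version are illustrative):

```python
# Parse "torch26-cxx11-cu126-x86_64-linux" into wheel metadata.
from pathlib import Path

from kernels.wheel import extract_metadata

meta = extract_metadata(
    "activation", "0.0.1", Path("build/torch26-cxx11-cu126-x86_64-linux")
)
assert meta.torch_version == "2.6"
assert meta.cuda_version == "126" and meta.cxx_abi_version == "11"
```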
|  | def build_wheel( | ||||||
|  |     metadata: Metadata, | ||||||
|  |     *, | ||||||
|  |     variant_path: Path, | ||||||
|  |     wheel_dir: Path, | ||||||
|  |     manylinux_version: str = "2.28", | ||||||
|  |     python_version: str = "3.9", | ||||||
|  | ) -> Path: | ||||||
|  |     """ | ||||||
|  |     Build the wheel file. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         from wheel.wheelfile import WheelFile  # type: ignore | ||||||
|  |     except ImportError: | ||||||
|  |         raise ImportError( | ||||||
|  |             "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     name = metadata.name.replace("-", "_") | ||||||
|  |     python_version_flat = python_version.replace(".", "") | ||||||
|  |  | ||||||
|  |     if metadata.is_universal: | ||||||
|  |         python_tag = f"py{python_version_flat}" | ||||||
|  |         abi_tag = "none" | ||||||
|  |         platform_tag = "any" | ||||||
|  |         wheel_filename = ( | ||||||
|  |             f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl" | ||||||
|  |         ) | ||||||
|  |         dist_info_dir_name = f"{name}-{metadata.version}.dist-info" | ||||||
|  |         root_is_purelib = "true" | ||||||
|  |         requires_dist_torch = "torch" | ||||||
|  |     else: | ||||||
|  |         python_tag = f"cp{python_version_flat}" | ||||||
|  |         abi_tag = "abi3" | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             metadata.torch_version is None | ||||||
|  |             or metadata.cuda_version is None | ||||||
|  |             or metadata.cxx_abi_version is None | ||||||
|  |             or metadata.os is None | ||||||
|  |             or metadata.platform is None | ||||||
|  |         ): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels." | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}" | ||||||
|  |  | ||||||
|  |         if metadata.os == "linux": | ||||||
|  |             platform_tag = ( | ||||||
|  |                 f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}" | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}" | ||||||
|  |  | ||||||
|  |         wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" | ||||||
|  |         dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info" | ||||||
|  |         root_is_purelib = "false" | ||||||
|  |         requires_dist_torch = f"torch=={metadata.torch_version}.*" | ||||||
|  |  | ||||||
|  |     wheel_path = wheel_dir / wheel_filename | ||||||
|  |  | ||||||
|  |     wheel_msg = Message(email.policy.compat32) | ||||||
|  |     wheel_msg.add_header("Wheel-Version", "1.0") | ||||||
|  |     wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})") | ||||||
|  |     wheel_msg.add_header("Root-Is-Purelib", root_is_purelib) | ||||||
|  |     wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}") | ||||||
|  |  | ||||||
|  |     metadata_msg = Message(email.policy.compat32) | ||||||
|  |     metadata_msg.add_header("Metadata-Version", "2.1") | ||||||
|  |     metadata_msg.add_header("Name", name) | ||||||
|  |     metadata_msg.add_header("Version", metadata.version) | ||||||
|  |     metadata_msg.add_header("Summary", f"{name} kernel") | ||||||
|  |     metadata_msg.add_header("Requires-Python", ">=3.9") | ||||||
|  |     metadata_msg.add_header("Requires-Dist", requires_dist_torch) | ||||||
|  |  | ||||||
|  |     source_pkg_dir = variant_path / name | ||||||
|  |  | ||||||
|  |     with WheelFile(wheel_path, "w") as wheel_file: | ||||||
|  |         for root, dirnames, filenames in os.walk(source_pkg_dir): | ||||||
|  |             for filename in filenames: | ||||||
|  |                 if filename.endswith(".pyc"): | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |                 abs_filepath = os.path.join(root, filename) | ||||||
|  |                 entry_name = os.path.relpath(abs_filepath, variant_path) | ||||||
|  |                 wheel_file.write(abs_filepath, entry_name) | ||||||
|  |  | ||||||
|  |         wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL") | ||||||
|  |         wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8")) | ||||||
|  |  | ||||||
|  |         metadata_path = os.path.join(dist_info_dir_name, "METADATA") | ||||||
|  |         wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8")) | ||||||
|  |  | ||||||
|  |     return wheel_path | ||||||
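A hedged end-to-end sketch of the conversion above; paths and the version are illustrative, and the `wheel` package must be installed:

```python
# Convert a downloaded build variant into an installable wheel.
from pathlib import Path

from kernels.wheel import build_variant_to_wheel

wheel_path = build_variant_to_wheel(
    "kernels-community/activation",
    version="0.0.1",
    variant_path=Path("build/torch26-cxx11-cu126-x86_64-linux"),
    wheel_dir=Path("dist"),
)
# e.g. activation-0.0.1+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
print(wheel_path.name)
```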
							
								
								
									
tests/conftest.py (new file, 46 lines)
| @ -0,0 +1,46 @@ |
|  | import sys | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from kernels.utils import _get_privateuse_backend_name | ||||||
|  |  | ||||||
|  | has_cuda = ( | ||||||
|  |     hasattr(torch.version, "cuda") | ||||||
|  |     and torch.version.cuda is not None | ||||||
|  |     and torch.cuda.device_count() > 0 | ||||||
|  | ) | ||||||
|  | has_rocm = ( | ||||||
|  |     hasattr(torch.version, "hip") | ||||||
|  |     and torch.version.hip is not None | ||||||
|  |     and torch.cuda.device_count() > 0 | ||||||
|  | ) | ||||||
|  | has_xpu = ( | ||||||
|  |     hasattr(torch.version, "xpu") | ||||||
|  |     and torch.version.xpu is not None | ||||||
|  |     and torch.xpu.device_count() > 0 | ||||||
|  | ) | ||||||
|  | has_npu = _get_privateuse_backend_name() == "npu" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def pytest_addoption(parser): | ||||||
|  |     parser.addoption( | ||||||
|  |         "--token", | ||||||
|  |         action="store_true", | ||||||
|  |         help="run tests that require a token with write permissions", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def pytest_runtest_setup(item): | ||||||
|  |     if "cuda_only" in item.keywords and not has_cuda: | ||||||
|  |         pytest.skip("skipping CUDA-only test on host without CUDA") | ||||||
|  |     if "rocm_only" in item.keywords and not has_rocm: | ||||||
|  |         pytest.skip("skipping ROCm-only test on host without ROCm") | ||||||
|  |     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"): | ||||||
|  |         pytest.skip("skipping macOS-only test on non-macOS platform") | ||||||
|  |     if "xpu_only" in item.keywords and not has_xpu: | ||||||
|  |         pytest.skip("skipping XPU-only test on host without XPU") | ||||||
|  |     if "npu_only" in item.keywords and not has_npu: | ||||||
|  |         pytest.skip("skipping NPU-only test on host without NPU") | ||||||
|  |     if "token" in item.keywords and not item.config.getoption("--token"): | ||||||
|  |         pytest.skip("need --token option to run this test") | ||||||
| @ -1,54 +1,82 @@ | |||||||
| [ | [ | ||||||
|   { |   { | ||||||
|     "repo_id": "kernels-community/activation", |     "repo_id": "kernels-community/activation", | ||||||
|     "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0", |     "sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538", | ||||||
|     "variants": { |     "variants": { | ||||||
|       "torch25-cxx11-cu118-x86_64-linux": { |       "torch25-cxx11-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d", |         "hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx11-cu121-x86_64-linux": { |       "torch25-cxx11-cu121-x86_64-linux": { | ||||||
|         "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58", |         "hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx11-cu124-x86_64-linux": { |       "torch25-cxx11-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a", |         "hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu118-x86_64-linux": { |       "torch25-cxx98-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0", |         "hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu121-x86_64-linux": { |       "torch25-cxx98-cu121-x86_64-linux": { | ||||||
|         "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8", |         "hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu124-x86_64-linux": { |       "torch25-cxx98-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc", |         "hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu118-x86_64-linux": { |       "torch26-cxx11-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b", |         "hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu124-x86_64-linux": { |       "torch26-cxx11-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005", |         "hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx11-cu126-aarch64-linux": { | ||||||
|  |         "hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu126-x86_64-linux": { |       "torch26-cxx11-cu126-x86_64-linux": { | ||||||
|         "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc", |         "hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu118-x86_64-linux": { |       "torch26-cxx98-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f", |         "hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu124-x86_64-linux": { |       "torch26-cxx98-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf", |         "hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx98-cu126-aarch64-linux": { | ||||||
|  |         "hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu126-x86_64-linux": { |       "torch26-cxx98-cu126-x86_64-linux": { | ||||||
|         "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a", |         "hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch27-cxx11-cu118-x86_64-linux": { | ||||||
|  |         "hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch27-cxx11-cu126-aarch64-linux": { | ||||||
|  |         "hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch27-cxx11-cu126-x86_64-linux": { | ||||||
|  |         "hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch27-cxx11-cu128-aarch64-linux": { | ||||||
|  |         "hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch27-cxx11-cu128-x86_64-linux": { | ||||||
|  |         "hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  | |||||||
							
								
								
									
tests/layer_locking/kernels.lock (new file, 12 lines)
| @ -0,0 +1,12 @@ |
|  | [ | ||||||
|  |   { | ||||||
|  |     "repo_id": "kernels-test/versions", | ||||||
|  |     "sha": "dc142fd6c9920c993d32be6358b78957c58681c3", | ||||||
|  |     "variants": { | ||||||
|  |       "torch-universal": { | ||||||
|  |         "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | ] | ||||||
							
								
								
									
tests/layer_locking/pyproject.toml (new file, 2 lines)
| @ -0,0 +1,2 @@ |
|  | [tool.kernels.dependencies] | ||||||
|  | "kernels-test/versions" = ">=0.1.0,<0.2.0" | ||||||
| @ -1,7 +1,7 @@ | |||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| from kernels import get_kernel | from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @ -9,6 +9,25 @@ def kernel(): | |||||||
|     return get_kernel("kernels-community/activation") |     return get_kernel("kernels-community/activation") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def local_kernel_path(): | ||||||
|  |     package_name, path = install_kernel("kernels-community/activation", "main") | ||||||
|  |     # Path is the build variant path (build/torch-<...>); tests take | ||||||
|  |     # its grandparent to get the kernel repository path. | ||||||
|  |     return package_name, path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def local_kernel(local_kernel_path): | ||||||
|  |     package_name, path = local_kernel_path | ||||||
|  |     return get_local_kernel(path.parent.parent, package_name) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def metal_kernel(): | ||||||
|  |     return get_kernel("kernels-test/relu-metal") | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def universal_kernel(): | def universal_kernel(): | ||||||
|     return get_kernel("kernels-community/triton-scaled-mm") |     return get_kernel("kernels-community/triton-scaled-mm") | ||||||
| @ -21,6 +40,7 @@ def device(): | |||||||
|     return "cuda" |     return "cuda" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_gelu_fast(kernel, device): | def test_gelu_fast(kernel, device): | ||||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) |     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
| @ -36,6 +56,100 @@ def test_gelu_fast(kernel, device): | |||||||
|     assert torch.allclose(y, expected) |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | def test_local_kernel(local_kernel, device): | ||||||
|  |     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||||
|  |     y = torch.empty_like(x) | ||||||
|  |  | ||||||
|  |     local_kernel.gelu_fast(y, x) | ||||||
|  |  | ||||||
|  |     expected = torch.tensor( | ||||||
|  |         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||||
|  |         device=device, | ||||||
|  |         dtype=torch.float16, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | def test_local_kernel_path_types(local_kernel_path, device): | ||||||
|  |     package_name, path = local_kernel_path | ||||||
|  |  | ||||||
|  |     # Top-level repo path | ||||||
|  |     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071 | ||||||
|  |     kernel = get_local_kernel(path.parent.parent, package_name) | ||||||
|  |     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||||
|  |     y = torch.empty_like(x) | ||||||
|  |  | ||||||
|  |     kernel.gelu_fast(y, x) | ||||||
|  |     expected = torch.tensor( | ||||||
|  |         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||||
|  |         device=device, | ||||||
|  |         dtype=torch.float16, | ||||||
|  |     ) | ||||||
|  |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |     # Build directory path | ||||||
|  |     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build | ||||||
|  |     kernel = get_local_kernel(path.parent.parent / "build", package_name) | ||||||
|  |     y = torch.empty_like(x) | ||||||
|  |     kernel.gelu_fast(y, x) | ||||||
|  |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |     # Explicit package path | ||||||
|  |     # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux | ||||||
|  |     kernel = get_local_kernel(path, package_name) | ||||||
|  |     y = torch.empty_like(x) | ||||||
|  |     kernel.gelu_fast(y, x) | ||||||
|  |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.darwin_only | ||||||
|  | @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) | ||||||
|  | def test_relu_metal(metal_kernel, dtype): | ||||||
|  |     x = torch.arange(-10, 10, dtype=dtype, device="mps") | ||||||
|  |     y = metal_kernel.relu(x) | ||||||
|  |     assert torch.allclose(y, torch.relu(x)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "kernel_exists", | ||||||
|  |     [ | ||||||
|  |         ("kernels-community/activation", "main", True), | ||||||
|  |         ("kernels-community/triton-layer-norm", "main", True), | ||||||
|  |         # Repo only contains Torch 2.4 kernels (and we don't | ||||||
|  |         # support/test against this version). | ||||||
|  |         ("kernels-test/only-torch-2.4", "main", False), | ||||||
|  |         ("google-bert/bert-base-uncased", "87565a309", False), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | def test_has_kernel(kernel_exists): | ||||||
|  |     repo_id, revision, kernel = kernel_exists | ||||||
|  |     assert has_kernel(repo_id, revision=revision) == kernel | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_version(): | ||||||
|  |     kernel = get_kernel("kernels-test/versions") | ||||||
|  |     assert kernel.version() == "0.2.0" | ||||||
|  |     kernel = get_kernel("kernels-test/versions", version="<1.0.0") | ||||||
|  |     assert kernel.version() == "0.2.0" | ||||||
|  |     kernel = get_kernel("kernels-test/versions", version="<0.2.0") | ||||||
|  |     assert kernel.version() == "0.1.1" | ||||||
|  |     kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0") | ||||||
|  |     assert kernel.version() == "0.1.1" | ||||||
|  |  | ||||||
|  |     with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): | ||||||
|  |         get_kernel("kernels-test/versions", version=">0.2.0") | ||||||
|  |  | ||||||
|  |     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): | ||||||
|  |         kernel = get_kernel( | ||||||
|  |             "kernels-test/versions", revision="v0.1.0", version="<1.0.0" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_universal_kernel(universal_kernel): | def test_universal_kernel(universal_kernel): | ||||||
|     torch.manual_seed(0) |     torch.manual_seed(0) | ||||||
|     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") |     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") | ||||||
|  | |||||||
| @ -16,18 +16,21 @@ def device(): | |||||||
|     return "cuda" |     return "cuda" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_gelu_small(kernel, device, benchmark): | def test_gelu_small(kernel, device, benchmark): | ||||||
|     x = torch.randn(32, 32, dtype=torch.float16, device=device) |     x = torch.randn(32, 32, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|     benchmark(kernel.gelu_fast, y, x) |     benchmark(kernel.gelu_fast, y, x) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_gelu_medium(kernel, device, benchmark): | def test_gelu_medium(kernel, device, benchmark): | ||||||
|     x = torch.randn(128, 128, dtype=torch.float16, device=device) |     x = torch.randn(128, 128, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|     benchmark(kernel.gelu_fast, y, x) |     benchmark(kernel.gelu_fast, y, x) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_gelu_large(kernel, device, benchmark): | def test_gelu_large(kernel, device, benchmark): | ||||||
|     x = torch.randn(512, 512, dtype=torch.float16, device=device) |     x = torch.randn(512, 512, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|  | |||||||
							
								
								
									
49  tests/test_doctest.py  Normal file
| @@ -0,0 +1,49 @@ | |||||||
|  | import inspect | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | from mktestdocs import check_docstring, get_codeblock_members | ||||||
|  |  | ||||||
|  | import kernels | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def all_public_functions(): | ||||||
|  |     function_list = inspect.getmembers(kernels, inspect.isfunction) | ||||||
|  |     return [func for _, func in function_list] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def all_public_classes(): | ||||||
|  |     class_list = inspect.getmembers(kernels, inspect.isclass) | ||||||
|  |     return [cls for _, cls in class_list] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def all_public_class_members(): | ||||||
|  |     members = get_codeblock_members(*all_public_classes()) | ||||||
|  |     return members | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "func", | ||||||
|  |     all_public_functions(), | ||||||
|  |     ids=lambda d: d.__name__, | ||||||
|  | ) | ||||||
|  | def test_func_docstring(func): | ||||||
|  |     check_docstring(obj=func) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "cls", | ||||||
|  |     all_public_classes(), | ||||||
|  |     ids=lambda d: d.__name__, | ||||||
|  | ) | ||||||
|  | def test_class_docstring(cls): | ||||||
|  |     check_docstring(obj=cls) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "member", all_public_class_members(), ids=lambda d: d.__qualname__ | ||||||
|  | ) | ||||||
|  | def test_member_docstring(member): | ||||||
|  |     check_docstring(member) | ||||||
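
The three tests above execute the fenced Python examples embedded in every public docstring and fail if any example raises. A minimal illustration with a hypothetical function (not part of `kernels`):

    from mktestdocs import check_docstring

    def double(x: int) -> int:
        """Double `x`.

        ```python
        assert 2 * 21 == 42
        ```
        """
        return 2 * x

    # Runs the code block from the docstring; raises if the example fails.
    check_docstring(obj=double)
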
							
								
								
									
230  tests/test_interval_tree.py  Normal file
| @@ -0,0 +1,230 @@ | |||||||
|  | import random | ||||||
|  | from typing import Generic, List, Optional, Tuple, TypeVar | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from kernels._interval_tree import IntervalTree, _Node | ||||||
|  |  | ||||||
|  | T = TypeVar("T") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SimpleIntervalStore(Generic[T]): | ||||||
|  |     """A simple O(n) implementation that stores intervals in a list.""" | ||||||
|  |  | ||||||
|  |     def __init__(self): | ||||||
|  |         self.intervals: List[Tuple[int, int, T]] = [] | ||||||
|  |  | ||||||
|  |     def insert(self, start: int, end: int, data: T) -> None: | ||||||
|  |         """Insert an interval into the store.""" | ||||||
|  |         # Replace data if the interval already exists. | ||||||
|  |         for i, (existing_start, existing_end, existing_data) in enumerate( | ||||||
|  |             self.intervals | ||||||
|  |         ): | ||||||
|  |             if existing_start == start and existing_end == end: | ||||||
|  |                 self.intervals[i] = (start, end, data) | ||||||
|  |                 return | ||||||
|  |  | ||||||
|  |         self.intervals.append((start, end, data)) | ||||||
|  |  | ||||||
|  |     def find_smallest_interval(self, point: int) -> Optional[T]: | ||||||
|  |         """Find the best match using linear search.""" | ||||||
|  |         matches = [] | ||||||
|  |         for start, end, data in self.intervals: | ||||||
|  |             if start <= point <= end: | ||||||
|  |                 matches.append((start, end, data)) | ||||||
|  |  | ||||||
|  |         if not matches: | ||||||
|  |             return None | ||||||
|  |  | ||||||
|  |         # Return the smallest interval, breaking ties by memory location | ||||||
|  |         # when there are multiple matches with the same interval size. This | ||||||
|  |         # mirrors the ordering in the interval tree. | ||||||
|  |         best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2]))) | ||||||
|  |         return best_match[2] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def is_balanced(tree: IntervalTree[T]) -> bool: | ||||||
|  |     """Check if the AVL tree is properly balanced.""" | ||||||
|  |  | ||||||
|  |     def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]: | ||||||
|  |         if node is None: | ||||||
|  |             return True, 0 | ||||||
|  |  | ||||||
|  |         # Left and right subtrees should be balanced. | ||||||
|  |         left_balanced, left_height = check_balance(node.left) | ||||||
|  |         if not left_balanced: | ||||||
|  |             return False, -1 | ||||||
|  |  | ||||||
|  |         right_balanced, right_height = check_balance(node.right) | ||||||
|  |         if not right_balanced: | ||||||
|  |             return False, -1 | ||||||
|  |  | ||||||
|  |         # The difference in height should not exceed 1. | ||||||
|  |         if abs(left_height - right_height) > 1: | ||||||
|  |             return False, -1 | ||||||
|  |  | ||||||
|  |         # Check if the height is correct. | ||||||
|  |         expected_height = 1 + max(left_height, right_height) | ||||||
|  |         if node.height != expected_height: | ||||||
|  |             return False, -1 | ||||||
|  |  | ||||||
|  |         return True, expected_height | ||||||
|  |  | ||||||
|  |     balanced, _ = check_balance(tree.root) | ||||||
|  |     return balanced | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def populated_tree() -> IntervalTree[str]: | ||||||
|  |     """Provides a pre-populated IntervalTree for testing.""" | ||||||
|  |     tree = IntervalTree[str]() | ||||||
|  |     kernels = [ | ||||||
|  |         (80, 89, "Kernel_A_General_80_89"), | ||||||
|  |         (86, 89, "Kernel_B_Ampere_86_89"), | ||||||
|  |         (80, 86, "Kernel_C_Older_Ampere_80_86"), | ||||||
|  |         (70, 75, "Kernel_D_Volta_70_75"), | ||||||
|  |         (86, 87, "Kernel_E_Specific_86_87"), | ||||||
|  |     ] | ||||||
|  |     for start, end, name in kernels: | ||||||
|  |         tree.insert(start, end, name) | ||||||
|  |     return tree | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree): | ||||||
|  |     # Check that the smallest interval is selected when there are | ||||||
|  |     # multiple matching intervals. | ||||||
|  |     assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_find_single_match(populated_tree): | ||||||
|  |     assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75" | ||||||
|  |     assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_no_match_outside_all_ranges(populated_tree): | ||||||
|  |     # Check that no interval is found when the value is out of range | ||||||
|  |     # (too small/too large). | ||||||
|  |     assert populated_tree.find_smallest_interval(65) is None | ||||||
|  |     assert populated_tree.find_smallest_interval(95) is None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_no_match_in_gap_between_ranges(populated_tree): | ||||||
|  |     # Check that no interval is found when the value is between two | ||||||
|  |     # intervals. | ||||||
|  |     assert populated_tree.find_smallest_interval(78) is None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_boundary_conditions_start_and_end(populated_tree): | ||||||
|  |     # Test exact upper/lower bounds of intervals. | ||||||
|  |     assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86" | ||||||
|  |     assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_empty_tree(): | ||||||
|  |     # Searching in an empty tree should return None. | ||||||
|  |     empty_tree = IntervalTree[str]() | ||||||
|  |     assert empty_tree.find_smallest_interval(100) is None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_multiple_equally_specific_matches(): | ||||||
|  |     # Check that we pick the match in a stable way when there are | ||||||
|  |     # multiple matching intervals with the same size. | ||||||
|  |     tree = IntervalTree[str]() | ||||||
|  |     str1 = "First_Narrow_Kernel" | ||||||
|  |     str2 = "Second_Narrow_Kernel" | ||||||
|  |     tree.insert(10, 20, "Wide_Kernel") | ||||||
|  |     tree.insert(12, 17, str1) | ||||||
|  |     tree.insert(14, 19, str2) | ||||||
|  |  | ||||||
|  |     if id(str1) < id(str2): | ||||||
|  |         assert tree.find_smallest_interval(15) == str1 | ||||||
|  |     else: | ||||||
|  |         assert tree.find_smallest_interval(15) == str2 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_property_based_interval_tree(): | ||||||
|  |     # QuickCheck-style property-based testing: | ||||||
|  |     # | ||||||
|  |     # - Verify that the tree is balanced after each insertion. | ||||||
|  |     # - Verify the query against a simple list-based implementation. | ||||||
|  |  | ||||||
|  |     random.seed(42)  # For reproducible tests | ||||||
|  |  | ||||||
|  |     test_points = list(range(0, 101)) | ||||||
|  |  | ||||||
|  |     for _ in range(5): | ||||||
|  |         tree = IntervalTree[str]() | ||||||
|  |         simple = SimpleIntervalStore[str]() | ||||||
|  |  | ||||||
|  |         intervals = [] | ||||||
|  |         for i in range(100): | ||||||
|  |             start = random.randint(0, 90) | ||||||
|  |             end = random.randint(start, 100) | ||||||
|  |             data = f"interval_{i}_s{start}_e{end}" | ||||||
|  |             intervals.append((start, end, data)) | ||||||
|  |  | ||||||
|  |         for i, (start, end, data) in enumerate(intervals): | ||||||
|  |             tree.insert(start, end, data) | ||||||
|  |             simple.insert(start, end, data) | ||||||
|  |  | ||||||
|  |             # Check that tree is still balanced | ||||||
|  |             assert is_balanced( | ||||||
|  |                 tree | ||||||
|  |             ), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})" | ||||||
|  |  | ||||||
|  |             for point in test_points: | ||||||
|  |                 tree_result = tree.find_smallest_interval(point) | ||||||
|  |                 simple_result = simple.find_smallest_interval(point) | ||||||
|  |  | ||||||
|  |                 assert tree_result == simple_result, ( | ||||||
|  |                     f"Mismatch for point {point} after inserting {i+1} intervals. " | ||||||
|  |                     f"Tree: {tree_result}, Simple: {simple_result}. " | ||||||
|  |                     f"Last inserted: ({start}, {end})" | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_property_based_edge_cases(): | ||||||
|  |     random.seed(123) | ||||||
|  |  | ||||||
|  |     tree = IntervalTree[str]() | ||||||
|  |     simple = SimpleIntervalStore[str]() | ||||||
|  |  | ||||||
|  |     # Single-point intervals. | ||||||
|  |     for i in range(10): | ||||||
|  |         point = random.randint(0, 100) | ||||||
|  |         data = f"single_point_{i}_{point}" | ||||||
|  |         tree.insert(point, point, data) | ||||||
|  |         simple.insert(point, point, data) | ||||||
|  |  | ||||||
|  |         assert is_balanced( | ||||||
|  |             tree | ||||||
|  |         ), f"Tree unbalanced after inserting single point {point}" | ||||||
|  |  | ||||||
|  |         # Test the exact point and neighbors | ||||||
|  |         for test_point in [point - 1, point, point + 1]: | ||||||
|  |             if 0 <= test_point <= 100: | ||||||
|  |                 tree_result = tree.find_smallest_interval(test_point) | ||||||
|  |                 simple_result = simple.find_smallest_interval(test_point) | ||||||
|  |                 assert tree_result == simple_result | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_unique_intervals_override(): | ||||||
|  |     """Test that inserting an interval with the same start/end overrides the previous value.""" | ||||||
|  |     tree = IntervalTree[str]() | ||||||
|  |  | ||||||
|  |     tree.insert(10, 20, "original_value") | ||||||
|  |     assert tree.find_smallest_interval(15) == "original_value" | ||||||
|  |  | ||||||
|  |     tree.insert(10, 20, "new_value") | ||||||
|  |     assert tree.find_smallest_interval(15) == "new_value" | ||||||
|  |  | ||||||
|  |     tree.insert(10, 25, "different_interval") | ||||||
|  |     results = tree.search(15) | ||||||
|  |     assert "new_value" in results | ||||||
|  |     assert "different_interval" in results | ||||||
|  |     assert len(results) == 2 | ||||||
|  |  | ||||||
|  |     tree.insert(10, 20, "final_value") | ||||||
|  |     assert tree.find_smallest_interval(15) == "final_value" | ||||||
|  |  | ||||||
|  |     assert is_balanced(tree) | ||||||
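
Taken together, these tests fix the lookup contract: among all intervals containing a point, the narrowest wins, and ties break deterministically. A short usage sketch built only from the API exercised above:

    from kernels._interval_tree import IntervalTree

    tree = IntervalTree[str]()
    tree.insert(80, 89, "generic sm80-sm89 kernel")
    tree.insert(86, 89, "tuned sm86-sm89 kernel")

    # The narrowest containing interval is preferred...
    assert tree.find_smallest_interval(87) == "tuned sm86-sm89 kernel"
    # ...and points outside every interval yield None.
    assert tree.find_smallest_interval(75) is None
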
| @@ -1,8 +1,18 @@ | |||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | import torch.nn as nn | ||||||
|  |  | ||||||
| from kernels import load_kernel | from kernels import load_kernel | ||||||
| from kernels.cli import download_kernels | from kernels.cli import download_kernels | ||||||
|  | from kernels.layer import ( | ||||||
|  |     LockedLayerRepository, | ||||||
|  |     Mode, | ||||||
|  |     kernelize, | ||||||
|  |     use_kernel_forward_from_hub, | ||||||
|  |     use_kernel_mapping, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Mock download arguments class. | # Mock download arguments class. | ||||||
| @@ -17,8 +27,35 @@ def test_download_all_hash_validation(): | |||||||
|     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) |     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
| def test_load_locked(): | def test_load_locked(): | ||||||
|     project_dir = Path(__file__).parent / "kernel_locking" |     project_dir = Path(__file__).parent / "kernel_locking" | ||||||
|     # Also validates that hashing works correctly. |     # Also validates that hashing works correctly. | ||||||
|     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) |     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) | ||||||
|     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") |     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.cuda_only | ||||||
|  | def test_layer_locked(): | ||||||
|  |     project_dir = Path(__file__).parent / "layer_locking" | ||||||
|  |  | ||||||
|  |     @use_kernel_forward_from_hub("Version") | ||||||
|  |     class Version(nn.Module): | ||||||
|  |         def forward(self) -> str: | ||||||
|  |             return "0.0.0" | ||||||
|  |  | ||||||
|  |     version = Version() | ||||||
|  |  | ||||||
|  |     with use_kernel_mapping( | ||||||
|  |         { | ||||||
|  |             "Version": { | ||||||
|  |                 "cuda": LockedLayerRepository( | ||||||
|  |                     repo_id="kernels-test/versions", | ||||||
|  |                     layer_name="Version", | ||||||
|  |                     lockfile=project_dir / "kernels.lock", | ||||||
|  |                 ) | ||||||
|  |             }, | ||||||
|  |         } | ||||||
|  |     ): | ||||||
|  |         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||||
|  |         assert version() == "0.1.1" | ||||||
|  | |||||||
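
`test_layer_locked` resolves the `Version` layer through the lockfile pin, so `kernelize` swaps in the locked 0.1.1 release. For comparison, an unlocked mapping would point at a repository directly; a sketch assuming a `LayerRepository` counterpart with the same constructor shape (an assumption, it does not appear in this diff):

    from kernels.layer import LayerRepository, Mode, kernelize, use_kernel_mapping

    # Hypothetical unlocked mapping: resolve the layer from the repo's
    # current revision rather than from a lockfile pin. `version` is the
    # same module instance constructed in the test above.
    with use_kernel_mapping(
        {
            "Version": {
                "cuda": LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                )
            },
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
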
							
								
								
									
117  tests/test_kernel_upload.py  Normal file
| @@ -0,0 +1,117 @@ | |||||||
|  | import logging | ||||||
|  | import os | ||||||
|  | import re | ||||||
|  | import tempfile | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import List, Optional | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  | from huggingface_hub import delete_repo, model_info | ||||||
|  |  | ||||||
|  | from kernels.cli import upload_kernels | ||||||
|  |  | ||||||
|  | REPO_ID = "valid_org/kernels-upload-test" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | PY_CONTENT = """\ | ||||||
|  | #!/usr/bin/env python3 | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     print("Hello from torch-universal!") | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
|  | """ | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class UploadArgs: | ||||||
|  |     kernel_dir: str | ||||||
|  |     repo_id: str | ||||||
|  |     private: bool | ||||||
|  |     branch: Optional[str] = None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def next_filename(path: Path) -> Path: | ||||||
|  |     """ | ||||||
|  |     Given a path like foo_2050.py, return foo_2051.py. | ||||||
|  |     """ | ||||||
|  |     m = re.match(r"^(.*?)(\d+)(\.py)$", path.name) | ||||||
|  |     if not m: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Filename {path.name!r} does not match pattern <prefix>_<number>.py" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     prefix, number, suffix = m.groups() | ||||||
|  |     new_number = str(int(number) + 1).zfill(len(number)) | ||||||
|  |     return path.with_name(f"{prefix}{new_number}{suffix}") | ||||||
|  |  | ||||||
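
Because the incremented number is re-padded to the original width, zero-padding survives the rename:

    from pathlib import Path

    assert next_filename(Path("foo_2050.py")) == Path("foo_2051.py")
    assert next_filename(Path("foo_007.py")) == Path("foo_008.py")  # width kept
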
|  |  | ||||||
|  | def get_filename_to_change(repo_filenames): | ||||||
|  |     filename_to_change = None | ||||||
|  |     for f in repo_filenames: | ||||||
|  |         if "foo" in f and f.endswith(".py"): | ||||||
|  |             filename_to_change = os.path.basename(f) | ||||||
|  |             break | ||||||
|  |     # Fail clearly if no candidate file was found (instead of a NameError). | ||||||
|  |     assert filename_to_change | ||||||
|  |     return filename_to_change | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_filenames_from_a_repo(repo_id: str) -> List[str]: | ||||||
|  |     try: | ||||||
|  |         repo_info = model_info(repo_id=repo_id, files_metadata=True) | ||||||
|  |         repo_siblings = repo_info.siblings | ||||||
|  |         if repo_siblings is not None: | ||||||
|  |             return [f.rfilename for f in repo_siblings] | ||||||
|  |         else: | ||||||
|  |             raise ValueError("No repo siblings found.") | ||||||
|  |     except Exception as e: | ||||||
|  |         logging.error(f"Error connecting to the Hub: {e}.") | ||||||
|  |         # Return an empty list so callers can still iterate the result. | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.token | ||||||
|  | @pytest.mark.is_staging_test | ||||||
|  | @pytest.mark.parametrize("branch", (None, "foo")) | ||||||
|  | def test_kernel_upload_works_as_expected(branch): | ||||||
|  |     with tempfile.TemporaryDirectory() as tmpdir: | ||||||
|  |         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||||
|  |         build_dir = Path(path) | ||||||
|  |         build_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |         script_path = build_dir / "foo.py" | ||||||
|  |         script_path.write_text(PY_CONTENT) | ||||||
|  |         upload_kernels(UploadArgs(tmpdir, REPO_ID, False, branch)) | ||||||
|  |  | ||||||
|  |     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||||
|  |     assert any(script_path.name in f for f in repo_filenames) | ||||||
|  |     delete_repo(repo_id=REPO_ID) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.token | ||||||
|  | @pytest.mark.is_staging_test | ||||||
|  | def test_kernel_upload_deletes_as_expected(): | ||||||
|  |     with tempfile.TemporaryDirectory() as tmpdir: | ||||||
|  |         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||||
|  |         build_dir = Path(path) | ||||||
|  |         build_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |         script_path = build_dir / "foo_2025.py" | ||||||
|  |         script_path.write_text(PY_CONTENT) | ||||||
|  |         upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) | ||||||
|  |  | ||||||
|  |     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||||
|  |     filename_to_change = get_filename_to_change(repo_filenames) | ||||||
|  |  | ||||||
|  |     with tempfile.TemporaryDirectory() as tmpdir: | ||||||
|  |         path = f"{tmpdir}/build/torch-universal/upload_test" | ||||||
|  |         build_dir = Path(path) | ||||||
|  |         build_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |         changed_filename = next_filename(Path(filename_to_change)) | ||||||
|  |         script_path = build_dir / changed_filename | ||||||
|  |         script_path.write_text(PY_CONTENT) | ||||||
|  |         upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) | ||||||
|  |  | ||||||
|  |     repo_filenames = get_filenames_from_a_repo(REPO_ID) | ||||||
|  |     assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}" | ||||||
|  |     assert not any( | ||||||
|  |         str(filename_to_change) in k for k in repo_filenames | ||||||
|  |     ), f"{repo_filenames=}" | ||||||
|  |     delete_repo(repo_id=REPO_ID) | ||||||
										
											
File diff suppressed because it is too large