Mirror of https://github.com/huggingface/kernels.git
Synced 2025-11-04 14:14:31 +08:00

Compare commits: bump-0.1.6 ... v0.4.4 (48 commits)
Commits:
cf530c283a 437f910336 6f1a6067c8 1d14abcef0 6fd2112e22 70f56ff856
7178b0b86c 0bbf90a564 27d6ffcb80 f7bd21438b 6174febb4b ff55bc201b
3808108d62 c4a16ef462 9762794dd2 b7d6867c52 fbcd0f2ebd 5af46eca94
747dd66876 920590a592 5208ac4be5 22eaba2826 9521ba79a0 9861a5bdef
1c7c87c960 df45cf2795 cf0413efe5 851c13f666 b6a393612f 18ecd0ce69
b4ef1d60e5 a40756f306 3671158f47 2ddd473cf7 497dffb89e f036fd09cb
3e4c83c798 4116d6019e bd166b348a 386c2a104e c7516b9e50 a8dcd1f6bc
af7fdf9202 9426e7e290 df2c165d61 d89239464a 3212affd9e 7ff40a859c
.github/workflows/lint.yml (vendored, new file, 10 lines)
@@ -0,0 +1,10 @@
name: Lints
on: [push, pull_request]
jobs:
  lint:
    name: Run lints
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Run ruff
        uses: astral-sh/ruff-action@v3
.github/workflows/test.yml (vendored, 16 changes)
@@ -1,4 +1,4 @@
-name: Test hf-kernels
+name: Test kernels
 
 on:
   push:
@@ -26,6 +26,9 @@ jobs:
         python-version: ["3.10", "3.12"]
         torch-version: ["2.5.1", "2.6.0"]
 
+    env:
+      UV_PYTHON_PREFERENCE: only-managed
+
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
@@ -41,5 +44,16 @@ jobs:
       - name: Install the project
         run: uv sync --all-extras --dev
 
+      - name: Install setuptools for Triton-based test
+        run: uv pip install setuptools
+
+      - name: Check typing
+        run: uv run mypy src/kernels
+
       - name: Run tests
         run: uv run pytest tests
+
+      - name: Import check without torch
+        run: |
+          uv pip uninstall torch
+          python -c "import kernels"
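The final step above checks that a bare `import kernels` still succeeds after
torch has been uninstalled. A minimal sketch of the deferred-import pattern
that such a check enforces (an illustration of the pattern, not the package's
actual source):

```python
import importlib


def _require_torch():
    """Import torch only when a kernel is actually loaded."""
    try:
        return importlib.import_module("torch")
    except ImportError as exc:  # torch stays optional at import time
        raise ImportError("loading a kernel requires torch") from exc
```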
LICENSE (new file, 201 lines)
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md (98 changes)
@@ -1,11 +1,42 @@
-# hf-kernels
+# kernels
 
-Make sure you have `torch==2.5.1+cu124` installed.
+<div align="center">
+<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
+<p align="center">
+    <a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
+    <a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
+</p>
+</div>
+<hr/>
+
+The Kernel Hub allows Python libraries and applications to load compute
+kernels directly from the [Hub](https://hf.co/). To support this kind
+of dynamic loading, Hub kernels differ from traditional Python kernel
+packages in that they are made to be:
+
+- Portable: a kernel can be loaded from paths outside `PYTHONPATH`.
+- Unique: multiple versions of the same kernel can be loaded in the
+  same Python process.
+- Compatible: kernels must support all recent versions of Python and
+  the different PyTorch build configurations (various CUDA versions
+  and C++ ABIs). Furthermore, older C library versions must be supported.
+
+## 🚀 Quick Start
+
+Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
+
+```bash
+pip install kernels
+```
+
+Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
 
 ```python
 import torch
-from hf_kernels import get_kernel
+
+from kernels import get_kernel
+
 # Download optimized kernels from the Hugging Face hub
 activation = get_kernel("kernels-community/activation")
@@ -20,57 +51,14 @@ activation.gelu_fast(y, x)
 print(y)
 ```
 
-## Docker Reference
-
-build and run the reference [example/basic.py](example/basic.py) in a Docker container with the following commands:
-
-```bash
-docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
-docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
-```
-
-## Locking kernel versions
-
-Projects that use `setuptools` can lock the kernel versions that should be
-used. First specify the accepted versions in `pyproject.toml` and make
-sure that `hf-kernels` is a build dependency:
-
-```toml
-[build-system]
-requires = ["hf-kernels", "setuptools"]
-build-backend = "setuptools.build_meta"
-
-[tool.kernels.dependencies]
-"kernels-community/activation" = ">=0.0.1"
-```
-
-Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with
-the locked revisions. The locked revision will be used when loading a kernel with
-`get_locked_kernel`:
-
-```python
-from hf_kernels import get_locked_kernel
-
-activation = get_locked_kernel("kernels-community/activation")
-```
-
-**Note:** the lock file is included in the package metadata, so it will only be visible
-to `hf-kernels` after doing an (editable or regular) installation of your project.
-
-## Pre-downloading locked kernels
-
-Locked kernels can be pre-downloaded by running `hf-kernel download .` in your
-project directory. This will download the kernels to your local Hugging Face
-Hub cache.
-
-The pre-downloaded kernels are used by the `get_locked_kernel` function.
-`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
-want kernel loading to error when a kernel is not pre-downloaded, you can use
-the `load_kernel` function instead:
-
-````python
-```python
-from hf_kernels import load_kernel
-
-activation = load_kernel("kernels-community/activation")
-````
+You can [search for kernels](https://huggingface.co/models?other=kernel) on
+the Hub.
+
+## 📚 Documentation
+
+- [Using layers](docs/layers.md)
+- [Locking kernel versions](docs/locking.md)
+- [Environment variables](docs/env.md)
+- [Using kernels in a Docker container](docs/docker.md)
+- [Kernel requirements](docs/kernel-requirements.md)
+- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
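A runnable version of the Quick Start example above; the tensor shapes are
illustrative assumptions, and a CUDA device with `torch>=2.5` is required:

```python
import torch

from kernels import get_kernel

activation = get_kernel("kernels-community/activation")

x = torch.randn(8, 16, device="cuda")
y = torch.empty_like(x)  # gelu_fast writes its result into this buffer
activation.gelu_fast(y, x)
print(y)
```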
@@ -31,13 +31,13 @@ WORKDIR /app/kernel-test
 # install python dependencies
 RUN uv add torch==2.5.0 numpy
 
-# copy hf-kernels lib
-COPY src ./hf-kernels/src
-COPY pyproject.toml ./hf-kernels/pyproject.toml
-COPY README.md ./hf-kernels/README.md
+# copy kernels lib
+COPY src ./kernels/src
+COPY pyproject.toml ./kernels/pyproject.toml
+COPY README.md ./kernels/README.md
 
 # install library
-RUN uv pip install -e hf-kernels
+RUN uv pip install -e kernels
 
 # copy examples
 COPY examples ./examples
@@ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
 
 # command to run the script
 CMD ["uv", "run", "examples/basic.py"]
-# CMD ["ls", "hf-kernels"]
+# CMD ["ls", "kernels"]
docs/docker.md (new file, 8 lines)
@@ -0,0 +1,8 @@
# Using kernels in a Docker container

Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:

```bash
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
```
docs/env.md (new file, 10 lines)
@@ -0,0 +1,10 @@
# Environment variables

## `KERNELS_CACHE`

The directory to use as the local kernel cache. If not set, the cache
of the `huggingface_hub` package is used.

## `DISABLE_KERNEL_MAPPING`

Disables kernel mappings for [`layers`](layers.md).
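A usage sketch for both variables; the cache path and the value `"1"` are
assumptions, and both variables must be set before the `kernels` package
reads them:

```python
import os

os.environ["KERNELS_CACHE"] = "/tmp/kernels-cache"  # redirect the kernel cache
os.environ["DISABLE_KERNEL_MAPPING"] = "1"          # assumed truthy value

from kernels import get_kernel

activation = get_kernel("kernels-community/activation")  # cached under /tmp/kernels-cache
```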
docs/kernel-requirements.md (new file, 205 lines)
@@ -0,0 +1,205 @@
# Kernel requirements

Kernels on the Hub must fulfill the requirements outlined on this page.
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
to build conforming kernels.

## Directory layout

A kernel repository on the Hub must contain a `build` directory. This
directory contains build variants of a kernel in the form of directories
following the template
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
recommended build variants are:

- `torch25-cxx11-cu118-x86_64-linux`
- `torch25-cxx11-cu121-x86_64-linux`
- `torch25-cxx11-cu124-x86_64-linux`
- `torch25-cxx98-cu118-x86_64-linux`
- `torch25-cxx98-cu121-x86_64-linux`
- `torch25-cxx98-cu124-x86_64-linux`
- `torch26-cxx11-cu118-x86_64-linux`
- `torch26-cxx11-cu124-x86_64-linux`
- `torch26-cxx11-cu126-x86_64-linux`
- `torch26-cxx98-cu118-x86_64-linux`
- `torch26-cxx98-cu124-x86_64-linux`
- `torch26-cxx98-cu126-x86_64-linux`

This list will be updated as new PyTorch versions are released. Kernels
that are in pure Python (e.g. Triton kernels) only need to provide a
single build variant:

- `torch-universal`

Each variant directory should contain a single directory with the same name
as the repository (replacing `-` by `_`). For instance, kernels in the
`kernels-community/activation` repository have directories like
`build/<variant>/activation`. This directory
must be a Python package with an `__init__.py` file.
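For illustration, a loader could assemble such a variant name at runtime
roughly as follows; this helper and its heuristics are a sketch, not part of
the `kernels` API:

```python
import platform

import torch


def build_variant() -> str:
    framework = "torch" + "".join(torch.__version__.split(".")[:2])  # e.g. torch26
    cxxabi = "cxx11" if torch._C._GLIBCXX_USE_CXX11_ABI else "cxx98"
    cuda = "cu" + torch.version.cuda.replace(".", "")  # torch.version.cuda is None on CPU-only builds
    return f"{framework}-{cxxabi}-{cuda}-{platform.machine()}-linux"


print(build_variant())  # e.g. torch26-cxx11-cu124-x86_64-linux
```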
## Versioning

Kernels are versioned on the Hub using Git tags. Version tags must be of
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
to resolve version constraints.

## Native Python module

Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:

- Use the [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.
- Be compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
  This means that the extension **must not** use symbol versions higher than:

  - GLIBC 2.28
  - GLIBCXX 3.4.24
  - CXXABI 1.3.11
  - GCC 7.0.0

  These requirements can be checked with the ABI checker (see below).

- Have no dynamic library dependencies outside:

  - Torch;
  - CUDA/ROCm libraries installed as dependencies of Torch.

The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):

```bash
$ cargo install kernel-abi-check
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
✅ No compatibility issues found
```
## Torch extension

Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
in `torch.ops.<namespace>`. Since we allow loading of multiple versions of
a module in the same Python process, `namespace` must be unique for each
version of a kernel. Failing to do so will create clashes when different
versions of the same kernel are loaded. Two suggested ways of doing this
are:

- Appending a truncated SHA-1 hash of the git commit that the kernel was
  built from to the name of the extension.
- Appending random material to the name of the extension.

**Note:** we recommend against appending a version number or git tag.
Version numbers are typically not bumped on each commit, so users
might use two different commits that happen to have the same version
number. Git tags are not stable, so they do not provide a good way
of guaranteeing uniqueness of the namespace.
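A sketch of the uniqueness this buys: two builds of the same kernel register
their ops under distinct, per-commit namespaces (the names below are
hypothetical):

```python
import torch

# Built from commit e87e0ca:
ops_v1 = torch.ops.activation_e87e0ca
# Built from commit f3ab81c, loadable in the same process without clashing:
ops_v2 = torch.ops.activation_f3ab81c
```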
## Layers

A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers. See the [layers documentation](layers.md) for more information
on how to use layers.

### Writing layers

To make the extension of layers safe, the layers must fulfill the following
requirements:

- The layers are subclasses of `torch.nn.Module`.
- The layers are pure, meaning that they do not have their own state. This
  means that:
  - The layer must not define its own constructor.
  - The layer must not use class variables.
- No methods other than `forward` are defined.
- The `forward` method has a signature that is compatible with the
  `forward` method that it is extending.

The only exception to the _no class variables_ rule is the addition of a
`has_backward` class variable. This variable is used to indicate whether
the layer has a backward pass implemented (`True` when absent).

This is an example of a pure layer:

```python
class SiluAndMul(nn.Module):
    # This layer does not implement backward.
    has_backward: bool = False

    def forward(self, x: torch.Tensor):
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # `ops` is the kernel's own native ops module, defined elsewhere
        # in the kernel package.
        ops.silu_and_mul(out, x)
        return out
```

For some layers, the `forward` method has to use state from the adopting class.
In these cases, we recommend using type annotations to indicate which member
variables are expected. For instance:

```python
class LlamaRMSNorm(nn.Module):
    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return rms_norm_fn(
            hidden_states,
            self.weight,
            bias=None,
            residual=None,
            eps=self.variance_epsilon,
            dropout_p=0.0,
            prenorm=False,
            residual_in_fp32=False,
        )
```

This layer expects the adopting layer to have `weight` and `variance_epsilon`
member variables and uses them in the `forward` method.

### Exporting layers

To accommodate portable loading, `layers` must be defined in the main
`__init__.py` file. For example:

```python
from . import layers

__all__ = [
  # ...
  "layers",
  # ...
]
```
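A usage sketch: because `layers` is exported from the kernel's `__init__.py`,
it should be reachable on the module object returned by `get_kernel` (the
repository and layer name below are borrowed from the examples above):

```python
from kernels import get_kernel

kernel = get_kernel("kernels-community/activation")
silu_and_mul = kernel.layers.SiluAndMul()  # assumes this kernel exports the layer
```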
## Python requirements

- Python code must be compatible with Python 3.9 and later.
- All Python code imports from the kernel itself must be relative. For
  instance, if `module_b` in the example kernel `example` needs a function
  from `module_a`, import it as:

  ```python
  from .module_a import foo
  ```

  **Never use:**

  ```python
  # DO NOT DO THIS!

  from example.module_a import foo
  ```

  The latter would import from the module `example` that is in Python's
  global module dict. However, since we allow loading multiple versions
  of a module, we uniquely name the module.

- Only modules from the Python standard library, Torch, or the kernel itself
  can be imported.
docs/layers.md (new file, 79 lines)
@@ -0,0 +1,79 @@
# Layers

A kernel can provide layers in addition to kernel functions. A layer from
the Hub can replace the `forward` method of an existing layer for a certain
device type. This makes it possible to provide more performant kernels for
existing layers.

See [Kernel requirements](kernel-requirements.md) for more information on the
requirements of Hub layers.

## Making a layer extensible with kernels from the hub

### Using a decorator

A layer can be made extensible with the `use_kernel_forward_from_hub`
decorator. For example:

```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```

The decorator changes the layer so that other implementations of the `forward`
method can be registered under the name `SiluAndMul`.
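A self-contained sketch of the assumed fallback behavior: with no kernel
mapping registered, the decorated layer runs its original `forward`, so it
remains a drop-in module (imports shown for completeness):

```python
import torch
import torch.nn.functional as F
from torch import nn

from kernels import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


layer = SiluAndMul()
out = layer(torch.randn(4, 8))  # no mapping registered: native forward runs
assert out.shape == (4, 4)      # the split halves the last dimension
```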
### External layers

An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by monkeypatching it with the
`replace_kernel_forward_from_hub` function:

```python
from somelibrary import SiluAndMul

replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```

The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.

**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.

## Registering a hub kernel for a layer

Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    }
}

register_kernel_mapping(kernel_layer_mapping)
```

This will register the kernel mapping in the current context, which is
normally global. It is recommended to scope the mapping to where it is
used with the `use_kernel_mapping` context manager:

```python
with use_kernel_mapping(kernel_layer_mapping):
    # Use the layer for which the mapping is applied.
    ...
```

This ensures that the mapping is no longer active outside the
`with` scope.
docs/locking.md (new file, 44 lines)
@@ -0,0 +1,44 @@
# Locking kernel versions

Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
sure that `kernels` is a build dependency:

```toml
[build-system]
requires = ["kernels", "setuptools"]
build-backend = "setuptools.build_meta"

[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.1"
```

Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
the locked revisions. The locked revision will be used when loading a kernel with
`get_locked_kernel`:

```python
from kernels import get_locked_kernel

activation = get_locked_kernel("kernels-community/activation")
```

**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.

## Pre-downloading locked kernels

Locked kernels can be pre-downloaded by running `kernel download .` in your
project directory. This will download the kernels to your local Hugging Face
Hub cache.

The pre-downloaded kernels are used by the `get_locked_kernel` function.
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
want kernel loading to error when a kernel is not pre-downloaded, you can use
the `load_kernel` function instead:

```python
from kernels import load_kernel

activation = load_kernel("kernels-community/activation")
```
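A behavioral sketch contrasting the two loaders described above:
`get_locked_kernel` downloads when needed, while `load_kernel` insists on a
pre-downloaded kernel (the exact exception type raised is an assumption):

```python
from kernels import get_locked_kernel, load_kernel

activation = get_locked_kernel("kernels-community/activation")  # downloads if missing

try:
    activation = load_kernel("kernels-community/activation")  # cache-only
except Exception:
    print("kernel not pre-downloaded; run `kernel download .` first")
```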
@@ -1,6 +1,6 @@
 import torch
 
-from hf_kernels import get_kernel
+from kernels import get_kernel
 
 print("Starting examples/basic.py demo")
 
							
								
								
									
flake.lock  (new file, generated, 134 lines)
@@ -0,0 +1,134 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1737453259,
        "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "outlines-v0.1.4-tgi",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "flake-utils": "flake-utils",
        "nixpkgs": [
          "tgi-nix",
          "nixpkgs"
        ],
        "tgi-nix": "tgi-nix"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "tgi-nix": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1741617161,
        "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
        "owner": "huggingface",
        "repo": "text-generation-inference-nix",
        "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "ref": "kernels-0.2.0",
        "repo": "text-generation-inference-nix",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix  (new file, 54 lines)
@@ -0,0 +1,54 @@
{
  inputs = {
    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
          inherit (tgi-nix.lib) config;
          overlays = [
            tgi-nix.overlays.default
          ];
        };
      in
      {
        formatter = pkgs.nixfmt-rfc-style;
        devShells = with pkgs; rec {
          default = mkShell {
            buildInputs =
              [
                black
                mypy
                pyright
                ruff
              ]
              ++ (with python3.pkgs; [
                huggingface-hub
                pytest
                pytest-benchmark
                torch
                venvShellHook
              ]);

            venvDir = "./.venv";

            postVenvCreation = ''
              unset SOURCE_DATE_EPOCH
              ( python -m pip install --no-build-isolation --no-dependencies -e . )
            '';
          };
        };
      }
    );
}
pyproject.toml
@@ -1,20 +1,20 @@
 [project]
-name = "hf-kernels"
-version = "0.1.5"
-description = "Download cuda kernels"
+name = "kernels"
+version = "0.4.4"
+description = "Download compute kernels"
 authors = [
   { name = "OlivierDehaene", email = "olivier@huggingface.co" },
   { name = "Daniel de Kok", email = "daniel@huggingface.co" },
   { name = "David Holtz", email = "david@huggingface.co" },
   { name = "Nicolas Patry", email = "nicolas@huggingface.co" },
 ]
+license = { text = "Apache-2.0" }
 readme = "README.md"
 requires-python = ">= 3.9"
 dependencies = [
-  "huggingface-hub>=0.26.3",
-  "packaging>=24.2",
-  "tomli>=2.0.1; python_version<'3.11'",
-  "torch>=2.4",
+  "huggingface_hub>=0.26.0,<1.0",
+  "packaging>=20.0",
+  "tomli>=2.0; python_version<'3.11'",
 ]
 
 [build-system]
@@ -23,18 +23,46 @@ build-backend = "setuptools.build_meta"
 
 [dependency-groups]
 dev = [
+  "mypy == 1.14.1",
   "pytest >=8",
   # Whatever version is compatible with pytest.
   "pytest-benchmark",
+  "torch >=2.5",
 ]
 
+[project.optional-dependencies]
+torch = ["torch"]
+
 [project.scripts]
-hf-kernels = "hf_kernels.cli:main"
+kernels = "kernels.cli:main"
 
 [project.entry-points."egg_info.writers"]
-"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"
+"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
 
-#[build-system]
-#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
-#build-backend = "hf_kernels.build"
-#backend-path = ["src"]
+
+[tool.ruff]
+exclude = [
+  ".eggs",
+  ".git",
+  ".git-rewrite",
+  ".hg",
+  ".mypy_cache",
+  ".nox",
+  ".pants.d",
+  ".pytype",
+  ".ruff_cache",
+  ".svn",
+  ".tox",
+  ".venv",
+  ".venv*",
+  "__pypackages__",
+  "_build",
+  "build",
+  "dist",
+  "venv",
+]
+line-length = 119
+# Ignored rules:
+# "E501" -> line length violation
+lint.ignore = ["E501"]
+lint.select = ["E", "F", "I", "W"]
src/hf_kernels/__init__.py  (deleted, 3 lines)
@@ -1,3 +0,0 @@
from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel

__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
src/hf_kernels/build.py  (deleted, 144 lines)
@@ -1,144 +0,0 @@
"""
Python shims for the PEP 517 and PEP 660 build backend.

Major imports in this module are required to be lazy:
```
$ hyperfine \
     "/usr/bin/python3 -c \"print('hi')\"" \
     "/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\""
Base: Time (mean ± σ):      11.0 ms ±   1.7 ms    [User: 8.5 ms, System: 2.5 ms]
With import: Time (mean ± σ):      15.2 ms ±   2.0 ms    [User: 12.3 ms, System: 2.9 ms]
Base 1.38 ± 0.28 times faster than with import
```

The same thing goes for the typing module, so we use Python 3.10 type annotations that
don't require importing typing but then quote them so earlier Python versions ignore
them while IDEs and type checkers can see through the quotes.
"""

from hf_kernels.compat import tomllib

TYPE_CHECKING = False
if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence  # noqa:I001
    from typing import Any  # noqa:I001


def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
    import sys

    if config_settings:
        print("Warning: Config settings are not supported", file=sys.stderr)


def call(
    args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None
) -> str:
    """Invoke a uv subprocess and return the filename from stdout."""
    import shutil
    import subprocess
    import sys

    warn_config_settings(config_settings)
    # Unlike `find_uv_bin`, this mechanism must work according to PEP 517
    import os

    cwd = os.getcwd()
    filename = os.path.join(cwd, "pyproject.toml")
    with open(filename, "rb") as f:
        data = tomllib.load(f)

    for kernel, _ in (
        data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items()
    ):
        from hf_kernels.utils import install_kernel

        install_kernel(kernel, revision="main")
    uv_bin = shutil.which("uv")
    if uv_bin is None:
        raise RuntimeError("uv was not properly installed")
    # Forward stderr, capture stdout for the filename
    result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE)
    if result.returncode != 0:
        sys.exit(result.returncode)
    # If there was extra stdout, forward it (there should not be extra stdout)
    stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True)
    sys.stdout.writelines(stdout[:-1])
    # Fail explicitly instead of an irrelevant stacktrace
    if not stdout:
        print("uv subprocess did not return a filename on stdout", file=sys.stderr)
        sys.exit(1)
    return stdout[-1].strip()


def build_sdist(
    sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
    """PEP 517 hook `build_sdist`."""
    args = ["build-backend", "build-sdist", sdist_directory]
    return call(args, config_settings)


def build_wheel(
    wheel_directory: str,
    config_settings: "Mapping[Any, Any] | None" = None,
    metadata_directory: "str | None" = None,
) -> str:
    """PEP 517 hook `build_wheel`."""
    args = ["build-backend", "build-wheel", wheel_directory]
    if metadata_directory:
        args.extend(["--metadata-directory", metadata_directory])
    return call(args, config_settings)


def get_requires_for_build_sdist(
    config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
    """PEP 517 hook `get_requires_for_build_sdist`."""
    warn_config_settings(config_settings)
    return []


def get_requires_for_build_wheel(
    config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
    """PEP 517 hook `get_requires_for_build_wheel`."""
    warn_config_settings(config_settings)
    return []


def prepare_metadata_for_build_wheel(
    metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
    """PEP 517 hook `prepare_metadata_for_build_wheel`."""
    args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory]
    return call(args, config_settings)


def build_editable(
    wheel_directory: str,
    config_settings: "Mapping[Any, Any] | None" = None,
    metadata_directory: "str | None" = None,
) -> str:
    """PEP 660 hook `build_editable`."""
    args = ["build-backend", "build-editable", wheel_directory]

    if metadata_directory:
        args.extend(["--metadata-directory", metadata_directory])
    return call(args, config_settings)


def get_requires_for_build_editable(
    config_settings: "Mapping[Any, Any] | None" = None,
) -> "Sequence[str]":
    """PEP 660 hook `get_requires_for_build_editable`."""
    warn_config_settings(config_settings)
    return []


def prepare_metadata_for_build_editable(
    metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
) -> str:
    """PEP 660 hook `prepare_metadata_for_build_editable`."""
    args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory]
    return call(args, config_settings)
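For context, the removed shim's `call()` pre-installed every kernel declared under `[tool.hf-kernels.dependencies]` before handing the build to uv. A minimal sketch of that lookup, with a hypothetical dependency table (the version-spec value is made up):

```python
import tomllib  # Python 3.11+; the shim used hf_kernels.compat.tomllib instead

# Hypothetical table of the shape `call()` iterated over.
doc = tomllib.loads(
    """
[tool.hf-kernels.dependencies]
"kernels-community/activation" = ">=0.0.1"
"""
)

for kernel in doc.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}):
    # At this point the shim called install_kernel(kernel, revision="main").
    print("would pre-install:", kernel)
```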
src/hf_kernels/utils.py  (deleted, 163 lines)
@@ -1,163 +0,0 @@
import ctypes
import importlib
import importlib.metadata
import inspect
import json
import os
import platform
import sys
from importlib.metadata import Distribution
from types import ModuleType
from typing import List, Optional

from huggingface_hub import hf_hub_download, snapshot_download
from packaging.version import parse

from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock

CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)


def build_variant():
    import torch

    torch_version = parse(torch.__version__)
    cuda_version = parse(torch.version.cuda)
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cpu = platform.machine()
    os = platform.system().lower()

    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"


def import_from_path(module_name: str, file_path):
    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
    module_name = f"{module_name}_{path_hash}"
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
    package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
        "torch"
    ]["name"]
    repo_path = snapshot_download(
        repo_id,
        allow_patterns=f"build/{build_variant()}/*",
        cache_dir=CACHE_DIR,
        revision=revision,
        local_files_only=local_files_only,
    )
    return package_name, f"{repo_path}/build/{build_variant()}"


def install_kernel_all_variants(
    repo_id: str, revision: str, local_files_only: bool = False
):
    snapshot_download(
        repo_id,
        allow_patterns="build/*",
        cache_dir=CACHE_DIR,
        revision=revision,
        local_files_only=local_files_only,
    )


def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
    with open(
        hf_hub_download(
            repo_id,
            "build.toml",
            cache_dir=CACHE_DIR,
            revision=revision,
            local_files_only=local_files_only,
        ),
        "rb",
    ) as f:
        return tomllib.load(f)


def get_kernel(repo_id: str, revision: str = "main"):
    package_name, package_path = install_kernel(repo_id, revision=revision)
    return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def load_kernel(repo_id: str):
    """Get a pre-downloaded, locked kernel."""
    locked_sha = _get_caller_locked_kernel(repo_id)

    if locked_sha is None:
        raise ValueError(f"Kernel `{repo_id}` is not locked")

    filename = hf_hub_download(
        repo_id,
        "build.toml",
        cache_dir=CACHE_DIR,
        local_files_only=True,
        revision=locked_sha,
    )
    with open(filename, "rb") as f:
        metadata = tomllib.load(f)
    package_name = metadata["torch"]["name"]

    repo_path = os.path.dirname(filename)
    package_path = f"{repo_path}/build/{build_variant()}"
    return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def get_locked_kernel(repo_id: str, local_files_only: bool = False):
    """Get a kernel using a lock file."""
    locked_sha = _get_caller_locked_kernel(repo_id)

    if locked_sha is None:
        raise ValueError(f"Kernel `{repo_id}` is not locked")

    package_name, package_path = install_kernel(
        repo_id, locked_sha, local_files_only=local_files_only
    )

    return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")


def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
    for dist in _get_caller_distributions():
        lock_json = dist.read_text("hf-kernels.lock")
        if lock_json is not None:
            for kernel_lock_json in json.loads(lock_json):
                kernel_lock = KernelLock.from_json(kernel_lock_json)
                if kernel_lock.repo_id == repo_id:
                    return kernel_lock.sha
    return None


def _get_caller_distributions() -> List[Distribution]:
    module = _get_caller_module()
    if module is None:
        return []

    # Look up all possible distributions that this module could be from.
    package = module.__name__.split(".")[0]
    dist_names = importlib.metadata.packages_distributions().get(package)
    if dist_names is None:
        return []

    return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]


def _get_caller_module() -> Optional[ModuleType]:
    stack = inspect.stack()
    # Get first module in the stack that is not the current module.
    first_module = inspect.getmodule(stack[0][0])
    for frame in stack[1:]:
        module = inspect.getmodule(frame[0])
        if module is not None and module != first_module:
            return module
    return first_module
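To make the variant naming concrete, here is the string the deleted `build_variant()` above would produce under one assumed environment (PyTorch 2.5, CUDA 12.1, C++11 ABI, x86-64 Linux); the values are examples, not a statement about any particular build:

```python
# Assumed environment values; at runtime they come from torch and platform.
torch_major, torch_minor = 2, 5  # parse(torch.__version__)
cuda_major, cuda_minor = 12, 1   # parse(torch.version.cuda)
cxxabi = "cxx11"                 # torch.compiled_with_cxx11_abi() -> True
cpu = "x86_64"                   # platform.machine()
system = "linux"                 # platform.system().lower()

variant = f"torch{torch_major}{torch_minor}-{cxxabi}-cu{cuda_major}{cuda_minor}-{cpu}-{system}"
print(variant)  # torch25-cxx11-cu121-x86_64-linux
```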
							
								
								
									
src/kernels/__init__.py  (new file, 29 lines)
@@ -0,0 +1,29 @@
from kernels.layer import (
    Device,
    LayerRepository,
    register_kernel_mapping,
    replace_kernel_forward_from_hub,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)
from kernels.utils import (
    get_kernel,
    get_locked_kernel,
    has_kernel,
    install_kernel,
    load_kernel,
)

__all__ = [
    "get_kernel",
    "get_locked_kernel",
    "has_kernel",
    "load_kernel",
    "install_kernel",
    "use_kernel_forward_from_hub",
    "use_kernel_mapping",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "LayerRepository",
    "Device",
]
src/hf_kernels/cli.py → src/kernels/cli.py
@@ -4,14 +4,14 @@ import json
 import sys
 from pathlib import Path
 
-from hf_kernels.compat import tomllib
-from hf_kernels.lockfile import KernelLock, get_kernel_locks
-from hf_kernels.utils import install_kernel, install_kernel_all_variants
+from kernels.compat import tomllib
+from kernels.lockfile import KernelLock, get_kernel_locks
+from kernels.utils import install_kernel, install_kernel_all_variants
 
 
 def main():
     parser = argparse.ArgumentParser(
-        prog="hf-kernel", description="Manage compute kernels"
+        prog="kernel", description="Manage compute kernels"
     )
     subparsers = parser.add_subparsers(required=True)
 
@@ -41,15 +41,17 @@ def main():
 
 
 def download_kernels(args):
-    lock_path = args.project_dir / "hf-kernels.lock"
+    lock_path = args.project_dir / "kernels.lock"
 
     if not lock_path.exists():
-        print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
+        print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr)
         sys.exit(1)
 
-    with open(args.project_dir / "hf-kernels.lock", "r") as f:
+    with open(args.project_dir / "kernels.lock", "r") as f:
         lock_json = json.load(f)
 
+    all_successful = True
+
     for kernel_lock_json in lock_json:
         kernel_lock = KernelLock.from_json(kernel_lock_json)
         print(
@@ -57,9 +59,22 @@ def download_kernels(args):
             file=sys.stderr,
         )
         if args.all_variants:
-            install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
+            install_kernel_all_variants(
+                kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
+            )
         else:
-            install_kernel(kernel_lock.repo_id, kernel_lock.sha)
+            try:
+                install_kernel(
+                    kernel_lock.repo_id,
+                    kernel_lock.sha,
+                    variant_locks=kernel_lock.variants,
+                )
+            except FileNotFoundError as e:
+                print(e, file=sys.stderr)
+                all_successful = False
+
+    if not all_successful:
+        sys.exit(1)
 
 
 def lock_kernels(args):
@@ -72,7 +87,7 @@ def lock_kernels(args):
     for kernel, version in kernel_versions.items():
        all_locks.append(get_kernel_locks(kernel, version))
 
-    with open(args.project_dir / "hf-kernels.lock", "w") as f:
+    with open(args.project_dir / "kernels.lock", "w") as f:
         json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
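A sketch of the lock entries `download_kernels` now iterates over, using the `KernelLock`/`VariantLock` schema from the lockfile diff further below; the repository, commit SHA, and hash values here are hypothetical:

```python
import json

from kernels.lockfile import KernelLock

# Hypothetical kernels.lock content; one entry per locked kernel, with
# per-variant hashes keyed by the build-variant name.
lock_json = """
[
  {
    "repo_id": "kernels-community/activation",
    "sha": "0123456789abcdef0123456789abcdef01234567",
    "variants": {
      "torch25-cxx11-cu121-x86_64-linux": {
        "hash": "sha256-0000000000000000000000000000000000000000000000000000000000000000",
        "hash_type": "git_lfs_concat"
      }
    }
  }
]
"""

for entry in json.loads(lock_json):
    kernel_lock = KernelLock.from_json(entry)
    print(kernel_lock.repo_id, kernel_lock.sha, sorted(kernel_lock.variants))
```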
							
								
								
									
src/kernels/layer.py  (new file, 264 lines)
@@ -0,0 +1,264 @@
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union

from .utils import get_kernel

if TYPE_CHECKING:
    from torch import nn

_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))


@dataclass(frozen=True)
class Device:
    type: str

    # In the future we might add compute capabilities, etc.

    def __eq__(self, other):
        return isinstance(other, Device) and self.type == other.type

    def __hash__(self):
        return hash(self.type)


@dataclass
class LayerRepository:
    """
    Repository and name of a layer.
    """

    layer_name: str = field(
        metadata={"help": "The name of the layer in the kernel repository."}
    )
    repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
    revision: str = field(
        default="main", metadata={"help": "The revision of the layer."}
    )

    def __eq__(self, other):
        return (
            isinstance(other, LayerRepository)
            and self.layer_name == other.layer_name
            and self.repo_id == other.repo_id
            and self.revision == other.revision
        )

    def __hash__(self):
        return hash((self.layer_name, self.repo_id, self.revision))


_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
    "_KERNEL_MAPPING", default={}
)


def use_kernel_mapping(
    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
    *,
    inherit_mapping: bool = True,
):
    """
    Context manager that sets a mapping for the duration of the context.

    When `inherit_mapping` is set to `True` the current mapping will be
    extended by `mapping` inside the context. If it is `False`, only
    `mapping` is used inside the context.
    """

    class ContextManager:
        def __enter__(self):
            # Mappings always stack on previous mappings.
            if inherit_mapping:
                self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get()))
            else:
                self.token = _KERNEL_MAPPING.set({})
            register_kernel_mapping(mapping)

        def __exit__(self, exc_type, exc_value, traceback):
            _KERNEL_MAPPING.reset(self.token)

    return ContextManager()


def register_kernel_mapping(
    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
):
    """
    Allows one to register a mapping between a layer name and the corresponding
    kernel to use, depending on the device. This should be used in conjunction
    with the `use_kernel_forward_from_hub` decorator on the class name.
    Example usage:

    ```python
    from kernels import LayerRepository, register_kernel_mapping

    kernel_layer_mapping = {
      "LlamaRMSNorm": {
          "cuda": LayerRepository(
              repo_id="kernels-community/activation",
              layer_name="RmsNorm",
              revision="layers",
          ),
      },
    }
    register_kernel_mapping(kernel_layer_mapping)
    ```
    """
    # Merge with existing mappings.
    for new_kernel, new_device_repos in mapping.items():
        device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
        for new_device, new_repo in new_device_repos.items():
            if isinstance(new_device, str):
                device_repo[Device(type=new_device)] = new_repo
            else:
                device_repo[new_device] = new_repo


def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
    """
    Replace the forward function of a layer using a layer from the kernel hub.
    This function monkeypatches a layer, replacing the `forward` method
    of the layer with that of a layer from the hub. The replacement is done
    when a layer matching `layer_name` and device type is registered through
    `register_kernel_mapping`. The device type is inferred from the first
    argument to `forward`.
    """

    fallback_forward = cls.forward

    cached_layer: Dict[LayerRepository, nn.Module] = {}

    def forward(self, x, *args, **kwargs):
        if _DISABLE_KERNEL_MAPPING:
            return fallback_forward(self, x, *args, **kwargs)

        needs_backward = self.training
        kernel = _KERNEL_MAPPING.get().get(layer_name)
        if kernel is None:
            warnings.warn(
                "\n"
                f"No kernel mapping found for layer `{layer_name}`. "
                f"Check if the layer name matches one of the kernels in the mapping or add the kernel "
                f"you want to use to the mapping. Defaulting to original forward implementation."
            )
            if not use_fallback:
                raise ValueError(f"No layer mapping for `{layer_name}`")
            return fallback_forward(self, x, *args, **kwargs)

        device = getattr(x, "device", None)
        if device is None:
            return fallback_forward(self, x, *args, **kwargs)

        repo = kernel.get(Device(type=device.type))
        if repo is None:
            if not use_fallback:
                raise ValueError(
                    f"No layer mapping for `{layer_name}` with device type `{device.type}`"
                )
            return fallback_forward(self, x, *args, **kwargs)

        # Short-circuit if we already loaded the layer.
        layer = cached_layer.get(repo, None)
        if layer is not None:
            if needs_backward and not getattr(layer, "has_backward", True):
                return fallback_forward(self, x, *args, **kwargs)
            return layer.forward(self, x, *args, **kwargs)

        layer = _get_kernel_layer(
            repo_id=repo.repo_id,
            layer_name=repo.layer_name,
            revision=repo.revision,
        )

        # We have to validate against the original signature.
        orig_forward = cls.forward
        try:
            cls.forward = fallback_forward
            _validate_layer(check_cls=cls, cls=layer)
        finally:
            cls.forward = orig_forward

        cached_layer[repo] = layer

        if needs_backward and not getattr(layer, "has_backward", True):
            return fallback_forward(self, x, *args, **kwargs)
        return layer.forward(self, x, *args, **kwargs)

    cls.forward = forward


def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
    """
    Replace the forward function of a layer using a layer from the kernel hub.
    This decorator can be applied to a layer and replaces the forward method
    of the layer with that of a layer from the hub. The replacement is done
    when a layer matching `layer_name` and device type is registered through
    `register_kernel_mapping`. The device type is inferred from the first
    argument to `forward`.
    """

    def decorator(cls):
        replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
        return cls

    return decorator


def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
    """Get a layer from a kernel."""

    kernel = get_kernel(repo_id, revision=revision)

    if getattr(kernel, "layers", None) is None:
        raise ValueError(
            f"Kernel `{repo_id}` at revision `{revision}` does not define any layers."
        )

    layer = getattr(kernel.layers, layer_name, None)
    if layer is None:
        raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.")
    return layer


def _validate_layer(*, check_cls, cls):
    # The layer must at least have the following properties: (1) it
    # must be stateless; (2) the forward signature should correspond to
    # the signature it is replacing; (3) forward should not call other
    # methods.

    from torch import nn

    if not issubclass(cls, nn.Module):
        raise TypeError(f"Layer `{cls}` is not a Torch layer.")

    # We verify statelessness by checking that the class does not have its own
    # constructor (since the constructor could add member variables)...
    if cls.__init__ is not nn.Module.__init__:
        raise TypeError("Layer must not override nn.Module constructor.")

    # ... or predefined member variables.
    torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
    cls_members = {name for name, _ in inspect.getmembers(cls)}
    difference = cls_members - torch_module_members
    if difference != set() and difference != {"has_backward"}:
        raise TypeError("Layer must not contain additional members.")

    # Check whether the forward signatures are similar.
    params = inspect.signature(cls.forward).parameters
    ref_params = inspect.signature(check_cls.forward).parameters

    if len(params) != len(ref_params):
        raise TypeError(
            "Forward signature does not match: different number of arguments."
        )

    for param, ref_param in zip(params.values(), ref_params.values()):
        if param.kind != ref_param.kind:
            raise TypeError(
                f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind}))"
            )
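Putting the layer API from `layer.py` together, a minimal sketch of mapping a decorated layer to a hub kernel; the `kernels-community/activation` repository and its `SiluAndMul` layer are assumed here purely for illustration:

```python
import torch
from torch import nn

from kernels import LayerRepository, use_kernel_forward_from_hub, use_kernel_mapping


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fallback used whenever no kernel is mapped for this layer
        # name and device type (e.g. on CPU).
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]


mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
    },
}

layer = SiluAndMul()
x = torch.randn(4, 64)

# Inside the context, CUDA tensors dispatch to the hub layer; this CPU
# tensor falls back to the forward defined above.
with use_kernel_mapping(mapping):
    y = layer(x)
```

Outside the context the `ContextVar` is reset, so the same layer instance reverts to its fallback forward.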
@@ -1,33 +1,37 @@
+import hashlib
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Tuple

 from huggingface_hub import HfApi
+from huggingface_hub.hf_api import GitRefInfo
 from packaging.specifiers import SpecifierSet
 from packaging.version import InvalidVersion, Version

-from hf_kernels.compat import tomllib
+from kernels.compat import tomllib


 @dataclass
-class FileLock:
-    filename: str
-    blob_id: str
+class VariantLock:
+    hash: str
+    hash_type: str = "git_lfs_concat"


 @dataclass
 class KernelLock:
     repo_id: str
     sha: str
-    files: List[FileLock]
+    variants: Dict[str, VariantLock]

     @classmethod
     def from_json(cls, o: Dict):
-        files = [FileLock(**f) for f in o["files"]]
-        return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
+        variants = {
+            variant: VariantLock(**lock) for variant, lock in o["variants"].items()
+        }
+        return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)


-def _get_available_versions(repo_id: str):
+def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
     """Get kernel versions that are available in the repository."""
     versions = {}
     for tag in HfApi().list_repo_refs(repo_id).tags:
@@ -41,7 +45,7 @@ def _get_available_versions(repo_id: str):
     return versions


-def get_kernel_locks(repo_id: str, version_spec: str):
+def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
     """
     Get the locks for a kernel with the given version spec.

@@ -72,31 +76,55 @@ def get_kernel_locks(repo_id: str, version_spec: str):
             f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
         )

-    file_locks = []
+    variant_files: Dict[str, List[Tuple[bytes, str]]] = {}
     for sibling in r.siblings:
         if sibling.rfilename.startswith("build/torch"):
             if sibling.blob_id is None:
                 raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")

-            file_locks.append(
-                FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
-            )
+            path = Path(sibling.rfilename)
+            variant = path.parts[1]
+            filename = Path(*path.parts[2:])

-    return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
+            hash = sibling.lfs.sha256 if sibling.lfs is not None else sibling.blob_id
+
+            files = variant_files.setdefault(variant, [])
+
+            # Encode as posix for consistent slash handling, then encode
+            # as utf-8 for byte-wise sorting later.
+            files.append((filename.as_posix().encode("utf-8"), hash))
+
+    variant_locks = {}
+    for variant, files in variant_files.items():
+        m = hashlib.sha256()
+        for filename_bytes, hash in sorted(files):
+            # Filename as bytes.
+            m.update(filename_bytes)
+            # Git blob or LFS file hash as bytes.
+            m.update(bytes.fromhex(hash))
+
+        variant_locks[variant] = VariantLock(hash=f"sha256-{m.hexdigest()}")
+
+    return KernelLock(repo_id=repo_id, sha=r.sha, variants=variant_locks)


 def write_egg_lockfile(cmd, basename, filename):
     import logging

     cwd = Path.cwd()
-    with open(cwd / "pyproject.toml", "rb") as f:
+    pyproject_path = cwd / "pyproject.toml"
+    if not pyproject_path.exists():
+        # Nothing to do if the project doesn't have pyproject.toml.
+        return
+
+    with open(pyproject_path, "rb") as f:
         data = tomllib.load(f)

     kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
     if kernel_versions is None:
         return

-    lock_path = cwd / "hf-kernels.lock"
+    lock_path = cwd / "kernels.lock"
     if not lock_path.exists():
         logging.warning(f"Lock file {lock_path} does not exist")
         # Ensure that the file gets deleted in editable installs.
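To make the hashing above concrete: a variant lock is the SHA-256 over the byte-sorted `(relative POSIX path, raw blob digest)` pairs of one build variant, serialized as `sha256-<hex>`; this is what the `git_lfs_concat` hash type denotes. A minimal sketch with made-up entries (paths and hex digests are illustrative only):

    import hashlib

    # Hypothetical (path bytes, Git blob / LFS hex digest) pairs for one variant.
    entries = [
        (b"activation/__init__.py", "aa" * 20),  # 40 hex chars: Git blob SHA-1
        (b"activation/_ops.py", "bb" * 32),      # 64 hex chars: Git LFS SHA-256
    ]

    m = hashlib.sha256()
    for path_bytes, hex_digest in sorted(entries):
        m.update(path_bytes)
        m.update(bytes.fromhex(hex_digest))
    print(f"sha256-{m.hexdigest()}")  # same format as VariantLock.hash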
							
								
								
									
src/kernels/utils.py (new file, 348 lines)
@@ -0,0 +1,348 @@
import ctypes
import hashlib
import importlib
import importlib.metadata
import inspect
import json
import logging
import os
import platform
import sys
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple

from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse

from kernels.lockfile import KernelLock, VariantLock


def _get_cache_dir() -> Optional[str]:
    """Returns the kernels cache directory."""
    cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
    if cache_dir is not None:
        logging.warning(
            "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
        )
        return cache_dir

    return os.environ.get("KERNELS_CACHE", None)


CACHE_DIR: Optional[str] = _get_cache_dir()
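Note that `CACHE_DIR` is resolved once, at import time, so the environment variable only takes effect if it is set before the module is imported. A small sketch (the path is illustrative, and it assumes importing the `kernels` package loads this module):

    import os

    os.environ["KERNELS_CACHE"] = "/tmp/kernels-cache"  # must happen first

    import kernels  # CACHE_DIR is now /tmp/kernels-cache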
def build_variant() -> str:
    import torch

    if torch.version.cuda is not None:
        cuda_version = parse(torch.version.cuda)
        compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
    elif torch.version.hip is not None:
        rocm_version = parse(torch.version.hip.split("-")[0])
        compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
    else:
        raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")

    torch_version = parse(torch.__version__)
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cpu = platform.machine()
    os_name = platform.system().lower()  # Avoid shadowing the `os` module.

    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os_name}"


def universal_build_variant() -> str:
    # Once we support other frameworks, detection goes here.
    return "torch-universal"
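For example, the lock file later in this diff lists variants such as `torch25-cxx11-cu118-x86_64-linux` and `torch26-cxx98-cu126-x86_64-linux`, exactly the `torch{major}{minor}-{cxxabi}-{compute framework}-{cpu}-{os}` shape assembled here, plus the framework-independent `torch-universal`.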
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
    # We cannot use the module name as-is: after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique, using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
    module_name = f"{module_name}_{path_hash}"
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
def install_kernel(
    repo_id: str,
    revision: str,
    local_files_only: bool = False,
    variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
    """
    Download a kernel for the current environment to the cache.

    The output path is validated against the hashes in `variant_locks` when set.
    """
    package_name = package_name_from_repo_id(repo_id)
    variant = build_variant()
    universal_variant = universal_build_variant()
    repo_path = Path(
        snapshot_download(
            repo_id,
            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
            cache_dir=CACHE_DIR,
            revision=revision,
            local_files_only=local_files_only,
        )
    )

    variant_path = repo_path / "build" / variant
    universal_variant_path = repo_path / "build" / universal_variant

    if not variant_path.exists() and universal_variant_path.exists():
        # Fall back to universal variant.
        variant = universal_variant
        variant_path = universal_variant_path

    if variant_locks is not None:
        variant_lock = variant_locks.get(variant)
        if variant_lock is None:
            raise ValueError(f"No lock found for build variant: {variant}")
        validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)

    module_init_path = variant_path / package_name / "__init__.py"

    if not os.path.exists(module_init_path):
        raise FileNotFoundError(
            f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
        )

    return package_name, variant_path
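A minimal usage sketch; the repo id comes from the tests below, the revision is illustrative, and the import path assumes this module is importable as `kernels.utils`:

    from kernels.utils import install_kernel

    package_name, variant_path = install_kernel(
        "kernels-community/activation", revision="main"
    )
    # The importable package then lives at variant_path / package_name.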
def install_kernel_all_variants(
    repo_id: str,
    revision: str,
    local_files_only: bool = False,
    variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Path:
    repo_path = Path(
        snapshot_download(
            repo_id,
            allow_patterns="build/*",
            cache_dir=CACHE_DIR,
            revision=revision,
            local_files_only=local_files_only,
        )
    )

    if variant_locks is not None:
        for entry in (repo_path / "build").iterdir():
            variant = entry.parts[-1]

            variant_lock = variant_locks.get(variant)
            if variant_lock is None:
                raise ValueError(f"No lock found for build variant: {variant}")

            validate_kernel(
                repo_path=repo_path, variant=variant, hash=variant_lock.hash
            )

    return repo_path / "build"
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
    package_name, package_path = install_kernel(repo_id, revision=revision)
    return import_from_path(package_name, package_path / package_name / "__init__.py")


def has_kernel(repo_id: str, revision: str = "main") -> bool:
    """
    Check whether a kernel build exists for the current environment
    (Torch version and compute framework).
    """
    package_name = package_name_from_repo_id(repo_id)
    variant = build_variant()
    universal_variant = universal_build_variant()

    if file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{universal_variant}/{package_name}/__init__.py",
    ):
        return True

    return file_exists(
        repo_id,
        revision=revision,
        filename=f"build/{variant}/{package_name}/__init__.py",
    )
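Together these two functions support a probe-then-load pattern (the import mirrors the tests below):

    from kernels import get_kernel, has_kernel

    if has_kernel("kernels-community/activation"):
        activation = get_kernel("kernels-community/activation")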
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
    """
    Get a pre-downloaded, locked kernel.

    If `lockfile` is not specified, the lockfile will be loaded from the
    caller's package metadata.
    """
    if lockfile is None:
        locked_sha = _get_caller_locked_kernel(repo_id)
    else:
        with open(lockfile, "r") as f:
            locked_sha = _get_locked_kernel(repo_id, f.read())

    if locked_sha is None:
        raise ValueError(
            f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project."
        )

    package_name = package_name_from_repo_id(repo_id)

    variant = build_variant()
    universal_variant = universal_build_variant()

    repo_path = Path(
        snapshot_download(
            repo_id,
            allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
            cache_dir=CACHE_DIR,
            revision=locked_sha,
            local_files_only=True,
        )
    )

    variant_path = repo_path / "build" / variant
    universal_variant_path = repo_path / "build" / universal_variant
    if not variant_path.exists() and universal_variant_path.exists():
        # Fall back to universal variant.
        variant = universal_variant
        variant_path = universal_variant_path

    module_init_path = variant_path / package_name / "__init__.py"
    if not os.path.exists(module_init_path):
        raise FileNotFoundError(
            f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
        )

    return import_from_path(package_name, variant_path / package_name / "__init__.py")
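Since `load_kernel` downloads with `local_files_only=True`, it never touches the network: the locked revision must already be in the cache, e.g. from a prior `kernels download <project>`. A sketch mirroring the locking test below (the lockfile path is illustrative):

    from pathlib import Path

    from kernels import load_kernel

    activation = load_kernel(
        "kernels-community/activation", lockfile=Path("kernels.lock")
    )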
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
    """Get a kernel using a lock file."""
    locked_sha = _get_caller_locked_kernel(repo_id)

    if locked_sha is None:
        raise ValueError(f"Kernel `{repo_id}` is not locked")

    package_name, package_path = install_kernel(
        repo_id, locked_sha, local_files_only=local_files_only
    )

    return import_from_path(package_name, package_path / package_name / "__init__.py")
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
    for dist in _get_caller_distributions():
        lock_json = dist.read_text("kernels.lock")
        if lock_json is None:
            continue
        locked_sha = _get_locked_kernel(repo_id, lock_json)
        if locked_sha is not None:
            return locked_sha
    return None


def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]:
    for kernel_lock_json in json.loads(lock_json):
        kernel_lock = KernelLock.from_json(kernel_lock_json)
        if kernel_lock.repo_id == repo_id:
            return kernel_lock.sha
    return None


def _get_caller_distributions() -> List[Distribution]:
    module = _get_caller_module()
    if module is None:
        return []

    # Look up all possible distributions that this module could be from.
    package = module.__name__.split(".")[0]
    dist_names = importlib.metadata.packages_distributions().get(package)
    if dist_names is None:
        return []

    return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]


def _get_caller_module() -> Optional[ModuleType]:
    stack = inspect.stack()
    # Get first module in the stack that is not the current module.
    first_module = inspect.getmodule(stack[0][0])
    for frame in stack[1:]:
        module = inspect.getmodule(frame[0])
        if module is not None and module != first_module:
            return module
    return first_module
def validate_kernel(*, repo_path: Path, variant: str, hash: str):
    """Validate the given build variant of a kernel against a hash."""
    variant_path = repo_path / "build" / variant

    # Get the file paths. The first element is a byte-encoded relative path
    # used for sorting. The second element is the absolute path.
    files: List[Tuple[bytes, Path]] = []
    # Ideally we'd use Path.walk, but it's only available in Python 3.12.
    for dirpath, _, filenames in os.walk(variant_path):
        for filename in filenames:
            file_abs = Path(dirpath) / filename

            # Python creates cache files when importing modules, so only
            # hash files that are symlinked blobs.
            if file_abs.is_symlink():
                files.append(
                    (
                        file_abs.relative_to(variant_path).as_posix().encode("utf-8"),
                        file_abs,
                    )
                )

    m = hashlib.sha256()

    for filename_bytes, full_path in sorted(files):
        m.update(filename_bytes)

        blob_filename = full_path.resolve().name
        if len(blob_filename) == 40:
            # SHA-1 hashed, so a Git blob.
            m.update(git_hash_object(full_path.read_bytes()))
        elif len(blob_filename) == 64:
            # SHA-256 hashed, so a Git LFS blob.
            m.update(hashlib.sha256(full_path.read_bytes()).digest())
        else:
            raise ValueError(f"Unexpected blob filename length: {len(blob_filename)}")

    computed_hash = f"sha256-{m.hexdigest()}"
    if computed_hash != hash:
        raise ValueError(
            f"Lock file specifies kernel with hash {hash}, but downloaded kernel has hash: {computed_hash}"
        )
def git_hash_object(data: bytes, object_type: str = "blob"):
    """Calculate git SHA1 of data."""
    header = f"{object_type} {len(data)}\0".encode()
    m = hashlib.sha1()
    m.update(header)
    m.update(data)
    return m.digest()


def package_name_from_repo_id(repo_id: str) -> str:
    return repo_id.split("/")[-1].replace("-", "_")
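`git_hash_object` reproduces Git's object hashing (SHA-1 over a `"<type> <length>\0"` header plus the data), so its output can be checked against `git hash-object`; the second assertion below shows the repo-id-to-package mapping used throughout this file:

    # "hello\n" is the classic `git hash-object` example value.
    assert git_hash_object(b"hello\n").hex() == "ce013625030ba8dba906f756967f9e9ca394464a"
    assert package_name_from_repo_id("kernels-community/triton-scaled-mm") == "triton_scaled_mm"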
							
								
								
									
tests/kernel_locking/kernels.lock (new file, 66 lines)
@@ -0,0 +1,66 @@
[
  {
    "repo_id": "kernels-community/activation",
    "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
    "variants": {
      "torch25-cxx11-cu118-x86_64-linux": {
        "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
        "hash_type": "git_lfs_concat"
      },
      "torch25-cxx11-cu121-x86_64-linux": {
        "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
        "hash_type": "git_lfs_concat"
      },
      "torch25-cxx11-cu124-x86_64-linux": {
        "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
        "hash_type": "git_lfs_concat"
      },
      "torch25-cxx98-cu118-x86_64-linux": {
        "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
        "hash_type": "git_lfs_concat"
      },
      "torch25-cxx98-cu121-x86_64-linux": {
        "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
        "hash_type": "git_lfs_concat"
      },
      "torch25-cxx98-cu124-x86_64-linux": {
        "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx11-cu118-x86_64-linux": {
        "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx11-cu124-x86_64-linux": {
        "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx11-cu126-x86_64-linux": {
        "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx98-cu118-x86_64-linux": {
        "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx98-cu124-x86_64-linux": {
        "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
        "hash_type": "git_lfs_concat"
      },
      "torch26-cxx98-cu126-x86_64-linux": {
        "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
        "hash_type": "git_lfs_concat"
      }
    }
  },
  {
    "repo_id": "kernels-community/triton-scaled-mm",
    "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
    "variants": {
      "torch-universal": {
        "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
        "hash_type": "git_lfs_concat"
      }
    }
  }
]
							
								
								
									
tests/kernel_locking/pyproject.toml (new file, 3 lines)
@@ -0,0 +1,3 @@
[tool.kernels.dependencies]
"kernels-community/activation" = ">=0.0.2"
"kernels-community/triton-scaled-mm" = ">=0.0.2"
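This pyproject section is the input to the locking flow referenced by the error messages in utils.py: `kernels lock <project>` resolves these version specifiers into the kernels.lock file shown above, and `kernels download <project>` prefetches the locked builds so that the offline-only `load_kernel` can find them.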
@@ -1,6 +1,7 @@
 import pytest
 import torch
-from hf_kernels import get_kernel
+
+from kernels import get_kernel, has_kernel


 @pytest.fixture
@@ -8,6 +9,11 @@ def kernel():
     return get_kernel("kernels-community/activation")


+@pytest.fixture
+def universal_kernel():
+    return get_kernel("kernels-community/triton-scaled-mm")
+
+
 @pytest.fixture
 def device():
     if not torch.cuda.is_available():
@@ -28,3 +34,33 @@ def test_gelu_fast(kernel, device):
     )

     assert torch.allclose(y, expected)
+
+
+@pytest.mark.parametrize(
+    "kernel_exists",
+    [
+        ("kernels-community/activation", "main", True),
+        ("kernels-community/triton-layer-norm", "main", True),
+        # Repo only contains Torch 2.4 kernels (and we don't
+        # support/test against this version).
+        ("kernels-test/only-torch-2.4", "main", False),
+        ("google-bert/bert-base-uncased", "87565a309", False),
+    ],
+)
+def test_has_kernel(kernel_exists):
+    repo_id, revision, kernel = kernel_exists
+    assert has_kernel(repo_id, revision=revision) == kernel
+
+
+def test_universal_kernel(universal_kernel):
+    torch.manual_seed(0)
+    A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
+    B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
+    scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
+    scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")
+
+    out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
+    out_check = (A * scale_a) @ (B * scale_b)
+    out_check = out_check.to(torch.float16)
+
+    torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
@@ -1,6 +1,7 @@
 import pytest
 import torch
-from hf_kernels import get_kernel
+
+from kernels import get_kernel


 @pytest.fixture
							
								
								
									
tests/test_kernel_locking.py (new file, 24 lines)
@@ -0,0 +1,24 @@
from dataclasses import dataclass
from pathlib import Path

from kernels import load_kernel
from kernels.cli import download_kernels


# Mock download arguments class.
@dataclass
class DownloadArgs:
    all_variants: bool
    project_dir: Path


def test_download_all_hash_validation():
    project_dir = Path(__file__).parent / "kernel_locking"
    download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))


def test_load_locked():
    project_dir = Path(__file__).parent / "kernel_locking"
    # Also validates that hashing works correctly.
    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
							
								
								
									
tests/test_layer.py (new file, 277 lines)
@@ -0,0 +1,277 @@
import pytest
import torch
import torch.nn as nn
from torch.nn import functional as F

from kernels import (
    Device,
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping

kernel_layer_mapping = {
    "SiluAndMul": {
        Device(type="cuda"): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    },
    "SiluAndMulStringDevice": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    },
}

register_kernel_mapping(kernel_layer_mapping)
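The mapping registered above is keyed by kernel layer name and then by device; as the two entries show, the device may be given either as `Device(type="cuda")` or as the plain string `"cuda"`.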
class SiluAndMul(nn.Module):
    def __init__(self):
        super().__init__()
        # Used to check that we called the hub kernel.
        self.n_calls = 0

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self.n_calls += 1
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
    pass


@use_kernel_forward_from_hub("SiluAndMulStringDevice")
class SiluAndMulStringDevice(SiluAndMul):
    pass
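These subclasses add no code of their own: `use_kernel_forward_from_hub` marks a class so that, when the current mapping has an entry for its name and device, `forward` is dispatched to the hub kernel; otherwise the inherited Python `forward` runs. The tests below exploit this via `n_calls`, which only increments when the fallback path executes.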
def test_arg_kinds():
    @use_kernel_forward_from_hub("ArgKind")
    class ArgKind(nn.Module):
        def forward(
            self,
            arg1,
            arg2,
            *,
            kwarg1,
            kwarg2=42,
        ):
            return (arg1, arg2, kwarg1, kwarg2)

    arg_kind = ArgKind()
    assert arg_kind("foo", "bar", kwarg1="baz") == ("foo", "bar", "baz", 42)
    assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
    torch.random.manual_seed(0)

    silu_and_mul = SiluAndMul()
    X = torch.randn((32, 64), device=device)
    Y = silu_and_mul(X)

    silu_and_mul_with_kernel = cls()
    Y_kernel = silu_and_mul_with_kernel(X)

    torch.testing.assert_close(Y_kernel, Y)

    assert silu_and_mul.n_calls == 1
    if device == "cuda":
        assert silu_and_mul_with_kernel.n_calls == 0
    else:
        assert silu_and_mul_with_kernel.n_calls == 1
def test_layer_fallback_works():
    @use_kernel_forward_from_hub("SiluAndMulNonExisting")
    class SiluAndMulWithKernelFallback(SiluAndMul):
        pass

    # Check that we don't raise an exception for a non-existing kernel.
    SiluAndMulWithKernelFallback()
def test_mapping_contexts():
    assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}

    extra_mapping1 = {
        "TestKernel": {
            Device(type="cuda"): LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="layers",
            )
        }
    }

    with use_kernel_mapping(extra_mapping1):
        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "TestKernel",
        }

        extra_mapping2 = {
            "SiluAndMul": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-community/non-existing",
                    layer_name="SiluAndMul",
                    revision="layers",
                )
            }
        }

        with use_kernel_mapping(extra_mapping2):
            assert set(_KERNEL_MAPPING.get().keys()) == {
                "SiluAndMul",
                "SiluAndMulStringDevice",
                "TestKernel",
            }
            assert (
                _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
                == "kernels-community/non-existing"
            )

        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "TestKernel",
        }
        assert (
            _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
            == "kernels-community/activation"
        )

        with use_kernel_mapping(extra_mapping2, inherit_mapping=False):
            assert set(_KERNEL_MAPPING.get().keys()) == {
                "SiluAndMul",
            }
            assert (
                _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
                == "kernels-community/non-existing"
            )

        assert set(_KERNEL_MAPPING.get().keys()) == {
            "SiluAndMul",
            "SiluAndMulStringDevice",
            "TestKernel",
        }
        assert (
            _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
            == "kernels-community/activation"
        )

    assert set(_KERNEL_MAPPING.get().keys()) == {
        "SiluAndMul",
        "SiluAndMulStringDevice",
    }
def test_validate_kernel_layer():
    class BadLayer(nn.Module):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.foo = 42

    with pytest.raises(TypeError, match="not override"):
        _validate_layer(cls=BadLayer, check_cls=SiluAndMul)

    class BadLayer2(nn.Module):
        foo: int = 42

    with pytest.raises(TypeError, match="not contain additional members"):
        _validate_layer(cls=BadLayer2, check_cls=SiluAndMul)

    class BadLayer3(nn.Module):
        def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...

    with pytest.raises(TypeError, match="different number of arguments"):
        _validate_layer(cls=BadLayer3, check_cls=SiluAndMul)

    class BadLayer4(nn.Module):
        def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...

    with pytest.raises(TypeError, match="different kind of arguments"):
        _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
def test_fallback_used_when_training():
    @use_kernel_forward_from_hub("Linear")
    class TorchLinear(nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Used to check that we called the hub kernel.
            self.n_calls = 0

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.n_calls += 1
            return super().forward(input)

    linear = TorchLinear(32, 32).to("cuda")

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearImplicitBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 0

        linear.eval()
        linear(X)
        assert linear.n_calls == 0

    with use_kernel_mapping(
        {
            "Linear": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearNoBackward",
                )
            }
        }
    ):
        linear.train()
        X = torch.randn(10, 32, device="cuda")
        linear(X)
        assert linear.n_calls == 1

        linear.eval()
        linear(X)
        assert linear.n_calls == 1