mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-11-04 05:55:32 +08:00 
			
		
		
		
	Compare commits
	
		
			86 Commits
		
	
	
		
			v0.3.1
			...
			release-0.
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0a3c828359 | |||
| dfee307d54 | |||
| 93e5765611 | |||
| bf488208be | |||
| 2a14472e4c | |||
| 055a953552 | |||
| 692d5ad458 | |||
| 2139df57f4 | |||
| 8f9a77bb6a | |||
| 6c00194680 | |||
| d6b51eefb7 | |||
| d383fdd4b4 | |||
| 07e5e8481a | |||
| 88f55d4728 | |||
| e801ebf332 | |||
| 0ae07f05fc | |||
| 7611021100 | |||
| 767e7ccf13 | |||
| 1caa4c1393 | |||
| da701bf58a | |||
| 703664ed31 | |||
| a8a6564fa7 | |||
| c89e0fa9b9 | |||
| 176a601178 | |||
| cfa0c76ddc | |||
| bcc29915f9 | |||
| 6fbff7a9cb | |||
| f7490bd0a9 | |||
| 8069e3bf0c | |||
| c540d1e1d6 | |||
| 967ac581b8 | |||
| 81088d44e8 | |||
| 4a04c005e3 | |||
| 6d3c6daf20 | |||
| 071900fd69 | |||
| 2d2c6b14e0 | |||
| 03edc573b1 | |||
| c841a6c90d | |||
| c7a343f195 | |||
| 8d838f947d | |||
| b87e6fadbe | |||
| fc935d9874 | |||
| 3622e1f8dd | |||
| a7f3b2e8ed | |||
| a6ab5d83ba | |||
| 4f9f1abfb9 | |||
| f94b7780a6 | |||
| bd28883775 | |||
| 498429e322 | |||
| 09c991af4b | |||
| bcf8df5875 | |||
| 239afff6f5 | |||
| c5ec6b900a | |||
| 3a635eaeea | |||
| 32ec496c5a | |||
| 848c6db87b | |||
| fabb8c52d1 | |||
| d66260dd83 | |||
| daac8078fc | |||
| fcb9a80ce6 | |||
| c25bb32e6e | |||
| 2036892762 | |||
| 0f0de049cf | |||
| 59597df03e | |||
| 5e938ede40 | |||
| cf530c283a | |||
| 437f910336 | |||
| 6f1a6067c8 | |||
| 1d14abcef0 | |||
| 6fd2112e22 | |||
| 70f56ff856 | |||
| 7178b0b86c | |||
| 0bbf90a564 | |||
| 27d6ffcb80 | |||
| f7bd21438b | |||
| 6174febb4b | |||
| ff55bc201b | |||
| 3808108d62 | |||
| c4a16ef462 | |||
| 9762794dd2 | |||
| b7d6867c52 | |||
| fbcd0f2ebd | |||
| 5af46eca94 | |||
| 747dd66876 | |||
| 920590a592 | |||
| 5208ac4be5 | 
							
								
								
									
										17
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,17 @@
 | 
			
		||||
name: Build documentation
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
      - doc-builder*
 | 
			
		||||
      - v*-release
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build:
 | 
			
		||||
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
 | 
			
		||||
    with:
 | 
			
		||||
      commit_sha: ${{ github.sha }}
 | 
			
		||||
      package: kernels
 | 
			
		||||
    secrets:
 | 
			
		||||
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
 | 
			
		||||
							
								
								
									
										15
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,15 @@
 | 
			
		||||
name: Build PR Documentation
 | 
			
		||||
 | 
			
		||||
on: pull_request
 | 
			
		||||
 | 
			
		||||
concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
 | 
			
		||||
  cancel-in-progress: true
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build:
 | 
			
		||||
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
 | 
			
		||||
    with:
 | 
			
		||||
      commit_sha: ${{ github.event.pull_request.head.sha }}
 | 
			
		||||
      pr_number: ${{ github.event.number }}
 | 
			
		||||
      package: kernels
 | 
			
		||||
							
								
								
									
										21
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										21
									
								
								.github/workflows/lint.yml
									
									
									
									
										vendored
									
									
								
							@ -8,3 +8,24 @@ jobs:
 | 
			
		||||
      - uses: actions/checkout@v4
 | 
			
		||||
      - name: Run ruff
 | 
			
		||||
        uses: astral-sh/ruff-action@v3
 | 
			
		||||
 | 
			
		||||
  black:
 | 
			
		||||
    name: Run black check
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    env:
 | 
			
		||||
      UV_PYTHON_PREFERENCE: only-managed
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@v4
 | 
			
		||||
 | 
			
		||||
      - name: Install uv and set the python version
 | 
			
		||||
        uses: astral-sh/setup-uv@v5
 | 
			
		||||
        with:
 | 
			
		||||
          python-version: 3.12
 | 
			
		||||
 | 
			
		||||
      - name: Install black
 | 
			
		||||
        run: uv pip install black
 | 
			
		||||
 | 
			
		||||
      - name: Check formatting
 | 
			
		||||
        run: |
 | 
			
		||||
          uv run black --check src
 | 
			
		||||
          uv run black --check tests
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,120 @@
 | 
			
		||||
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI
 | 
			
		||||
 | 
			
		||||
on: push
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build:
 | 
			
		||||
    name: Build distribution 📦
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
      - uses: actions/checkout@v4
 | 
			
		||||
        with:
 | 
			
		||||
          persist-credentials: false
 | 
			
		||||
      - name: Set up Python
 | 
			
		||||
        uses: actions/setup-python@v5
 | 
			
		||||
        with:
 | 
			
		||||
          python-version: "3.9"
 | 
			
		||||
      - name: Install pypa/build
 | 
			
		||||
        run: >-
 | 
			
		||||
          python3 -m
 | 
			
		||||
          pip install
 | 
			
		||||
          build
 | 
			
		||||
          --user
 | 
			
		||||
      - name: Build a binary wheel and a source tarball
 | 
			
		||||
        run: python3 -m build
 | 
			
		||||
      - name: Store the distribution packages
 | 
			
		||||
        uses: actions/upload-artifact@v4
 | 
			
		||||
        with:
 | 
			
		||||
          name: python-package-distributions
 | 
			
		||||
          path: dist/
 | 
			
		||||
 | 
			
		||||
  publish-to-pypi:
 | 
			
		||||
    name: >-
 | 
			
		||||
      Publish Python 🐍 distribution 📦 to PyPI
 | 
			
		||||
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
 | 
			
		||||
    needs:
 | 
			
		||||
      - build
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    environment:
 | 
			
		||||
      name: pypi
 | 
			
		||||
      url: https://pypi.org/p/kernels
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write # IMPORTANT: mandatory for trusted publishing
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Download all the dists
 | 
			
		||||
        uses: actions/download-artifact@v4
 | 
			
		||||
        with:
 | 
			
		||||
          name: python-package-distributions
 | 
			
		||||
          path: dist/
 | 
			
		||||
      - name: Publish distribution 📦 to PyPI
 | 
			
		||||
        uses: pypa/gh-action-pypi-publish@release/v1
 | 
			
		||||
 | 
			
		||||
  github-release:
 | 
			
		||||
    name: >-
 | 
			
		||||
      Sign the Python 🐍 distribution 📦 with Sigstore
 | 
			
		||||
      and upload them to GitHub Release
 | 
			
		||||
    needs:
 | 
			
		||||
      - publish-to-pypi
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
 | 
			
		||||
    permissions:
 | 
			
		||||
      contents: write # IMPORTANT: mandatory for making GitHub Releases
 | 
			
		||||
      id-token: write # IMPORTANT: mandatory for sigstore
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Download all the dists
 | 
			
		||||
        uses: actions/download-artifact@v4
 | 
			
		||||
        with:
 | 
			
		||||
          name: python-package-distributions
 | 
			
		||||
          path: dist/
 | 
			
		||||
      - name: Sign the dists with Sigstore
 | 
			
		||||
        uses: sigstore/gh-action-sigstore-python@v3.0.0
 | 
			
		||||
        with:
 | 
			
		||||
          inputs: >-
 | 
			
		||||
            ./dist/*.tar.gz
 | 
			
		||||
            ./dist/*.whl
 | 
			
		||||
      - name: Create GitHub Release
 | 
			
		||||
        env:
 | 
			
		||||
          GITHUB_TOKEN: ${{ github.token }}
 | 
			
		||||
        run: >-
 | 
			
		||||
          gh release create
 | 
			
		||||
          "$GITHUB_REF_NAME"
 | 
			
		||||
          --repo "$GITHUB_REPOSITORY"
 | 
			
		||||
          --notes ""
 | 
			
		||||
      - name: Upload artifact signatures to GitHub Release
 | 
			
		||||
        env:
 | 
			
		||||
          GITHUB_TOKEN: ${{ github.token }}
 | 
			
		||||
        # Upload to GitHub Release using the `gh` CLI.
 | 
			
		||||
        # `dist/` contains the built packages, and the
 | 
			
		||||
        # sigstore-produced signatures and certificates.
 | 
			
		||||
        run: >-
 | 
			
		||||
          gh release upload
 | 
			
		||||
          "$GITHUB_REF_NAME" dist/**
 | 
			
		||||
          --repo "$GITHUB_REPOSITORY"
 | 
			
		||||
 | 
			
		||||
  publish-to-testpypi:
 | 
			
		||||
    name: Publish Python 🐍 distribution 📦 to TestPyPI
 | 
			
		||||
    needs:
 | 
			
		||||
      - build
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
 | 
			
		||||
    environment:
 | 
			
		||||
      name: testpypi
 | 
			
		||||
      url: https://test.pypi.org/p/kernels
 | 
			
		||||
 | 
			
		||||
    permissions:
 | 
			
		||||
      id-token: write # IMPORTANT: mandatory for trusted publishing
 | 
			
		||||
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Download all the dists
 | 
			
		||||
        uses: actions/download-artifact@v4
 | 
			
		||||
        with:
 | 
			
		||||
          name: python-package-distributions
 | 
			
		||||
          path: dist/
 | 
			
		||||
      - name: Publish distribution 📦 to TestPyPI
 | 
			
		||||
        uses: pypa/gh-action-pypi-publish@release/v1
 | 
			
		||||
        with:
 | 
			
		||||
          repository-url: https://test.pypi.org/legacy/
 | 
			
		||||
          skip-existing: true # Only upload when the version is unique.
 | 
			
		||||
							
								
								
									
										28
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										28
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							@ -24,7 +24,7 @@ jobs:
 | 
			
		||||
      max-parallel: 4
 | 
			
		||||
      matrix:
 | 
			
		||||
        python-version: ["3.10", "3.12"]
 | 
			
		||||
        torch-version: ["2.5.1", "2.6.0"]
 | 
			
		||||
        torch-version: ["2.6.0", "2.7.0"]
 | 
			
		||||
 | 
			
		||||
    env:
 | 
			
		||||
      UV_PYTHON_PREFERENCE: only-managed
 | 
			
		||||
@ -51,4 +51,28 @@ jobs:
 | 
			
		||||
        run: uv run mypy src/kernels
 | 
			
		||||
 | 
			
		||||
      - name: Run tests
 | 
			
		||||
        run: uv run pytest tests
 | 
			
		||||
        run: |
 | 
			
		||||
          uv run pytest tests
 | 
			
		||||
 | 
			
		||||
      - name: Run staging tests
 | 
			
		||||
        env: 
 | 
			
		||||
          HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }}
 | 
			
		||||
        run: |
 | 
			
		||||
          HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/
 | 
			
		||||
 | 
			
		||||
      - name: Check kernel conversion
 | 
			
		||||
        run: |
 | 
			
		||||
          uv pip install wheel
 | 
			
		||||
          uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
 | 
			
		||||
          uv pip install triton_layer_norm-0.0.1*.whl
 | 
			
		||||
          uv run python -c "import triton_layer_norm"
 | 
			
		||||
 | 
			
		||||
      - name: Check README generation
 | 
			
		||||
        # For now, just checks that generation doesn't fail.
 | 
			
		||||
        run: |
 | 
			
		||||
          uv run kernels generate-readme kernels-community/triton-layer-norm
 | 
			
		||||
 | 
			
		||||
      - name: Import check without torch
 | 
			
		||||
        run: |
 | 
			
		||||
          uv pip uninstall torch
 | 
			
		||||
          python -c "import kernels"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										16
									
								
								.github/workflows/upload_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								.github/workflows/upload_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
			
		||||
name: Upload PR Documentation
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  workflow_run:
 | 
			
		||||
    workflows: ["Build PR Documentation"]
 | 
			
		||||
    types:
 | 
			
		||||
      - completed
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  build:
 | 
			
		||||
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
 | 
			
		||||
    with:
 | 
			
		||||
      package_name: kernels
 | 
			
		||||
    secrets:
 | 
			
		||||
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
 | 
			
		||||
      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
 | 
			
		||||
							
								
								
									
										201
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								LICENSE
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,201 @@
 | 
			
		||||
                                 Apache License
 | 
			
		||||
                           Version 2.0, January 2004
 | 
			
		||||
                        http://www.apache.org/licenses/
 | 
			
		||||
 | 
			
		||||
   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 | 
			
		||||
 | 
			
		||||
   1. Definitions.
 | 
			
		||||
 | 
			
		||||
      "License" shall mean the terms and conditions for use, reproduction,
 | 
			
		||||
      and distribution as defined by Sections 1 through 9 of this document.
 | 
			
		||||
 | 
			
		||||
      "Licensor" shall mean the copyright owner or entity authorized by
 | 
			
		||||
      the copyright owner that is granting the License.
 | 
			
		||||
 | 
			
		||||
      "Legal Entity" shall mean the union of the acting entity and all
 | 
			
		||||
      other entities that control, are controlled by, or are under common
 | 
			
		||||
      control with that entity. For the purposes of this definition,
 | 
			
		||||
      "control" means (i) the power, direct or indirect, to cause the
 | 
			
		||||
      direction or management of such entity, whether by contract or
 | 
			
		||||
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
 | 
			
		||||
      outstanding shares, or (iii) beneficial ownership of such entity.
 | 
			
		||||
 | 
			
		||||
      "You" (or "Your") shall mean an individual or Legal Entity
 | 
			
		||||
      exercising permissions granted by this License.
 | 
			
		||||
 | 
			
		||||
      "Source" form shall mean the preferred form for making modifications,
 | 
			
		||||
      including but not limited to software source code, documentation
 | 
			
		||||
      source, and configuration files.
 | 
			
		||||
 | 
			
		||||
      "Object" form shall mean any form resulting from mechanical
 | 
			
		||||
      transformation or translation of a Source form, including but
 | 
			
		||||
      not limited to compiled object code, generated documentation,
 | 
			
		||||
      and conversions to other media types.
 | 
			
		||||
 | 
			
		||||
      "Work" shall mean the work of authorship, whether in Source or
 | 
			
		||||
      Object form, made available under the License, as indicated by a
 | 
			
		||||
      copyright notice that is included in or attached to the work
 | 
			
		||||
      (an example is provided in the Appendix below).
 | 
			
		||||
 | 
			
		||||
      "Derivative Works" shall mean any work, whether in Source or Object
 | 
			
		||||
      form, that is based on (or derived from) the Work and for which the
 | 
			
		||||
      editorial revisions, annotations, elaborations, or other modifications
 | 
			
		||||
      represent, as a whole, an original work of authorship. For the purposes
 | 
			
		||||
      of this License, Derivative Works shall not include works that remain
 | 
			
		||||
      separable from, or merely link (or bind by name) to the interfaces of,
 | 
			
		||||
      the Work and Derivative Works thereof.
 | 
			
		||||
 | 
			
		||||
      "Contribution" shall mean any work of authorship, including
 | 
			
		||||
      the original version of the Work and any modifications or additions
 | 
			
		||||
      to that Work or Derivative Works thereof, that is intentionally
 | 
			
		||||
      submitted to Licensor for inclusion in the Work by the copyright owner
 | 
			
		||||
      or by an individual or Legal Entity authorized to submit on behalf of
 | 
			
		||||
      the copyright owner. For the purposes of this definition, "submitted"
 | 
			
		||||
      means any form of electronic, verbal, or written communication sent
 | 
			
		||||
      to the Licensor or its representatives, including but not limited to
 | 
			
		||||
      communication on electronic mailing lists, source code control systems,
 | 
			
		||||
      and issue tracking systems that are managed by, or on behalf of, the
 | 
			
		||||
      Licensor for the purpose of discussing and improving the Work, but
 | 
			
		||||
      excluding communication that is conspicuously marked or otherwise
 | 
			
		||||
      designated in writing by the copyright owner as "Not a Contribution."
 | 
			
		||||
 | 
			
		||||
      "Contributor" shall mean Licensor and any individual or Legal Entity
 | 
			
		||||
      on behalf of whom a Contribution has been received by Licensor and
 | 
			
		||||
      subsequently incorporated within the Work.
 | 
			
		||||
 | 
			
		||||
   2. Grant of Copyright License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      copyright license to reproduce, prepare Derivative Works of,
 | 
			
		||||
      publicly display, publicly perform, sublicense, and distribute the
 | 
			
		||||
      Work and such Derivative Works in Source or Object form.
 | 
			
		||||
 | 
			
		||||
   3. Grant of Patent License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      (except as stated in this section) patent license to make, have made,
 | 
			
		||||
      use, offer to sell, sell, import, and otherwise transfer the Work,
 | 
			
		||||
      where such license applies only to those patent claims licensable
 | 
			
		||||
      by such Contributor that are necessarily infringed by their
 | 
			
		||||
      Contribution(s) alone or by combination of their Contribution(s)
 | 
			
		||||
      with the Work to which such Contribution(s) was submitted. If You
 | 
			
		||||
      institute patent litigation against any entity (including a
 | 
			
		||||
      cross-claim or counterclaim in a lawsuit) alleging that the Work
 | 
			
		||||
      or a Contribution incorporated within the Work constitutes direct
 | 
			
		||||
      or contributory patent infringement, then any patent licenses
 | 
			
		||||
      granted to You under this License for that Work shall terminate
 | 
			
		||||
      as of the date such litigation is filed.
 | 
			
		||||
 | 
			
		||||
   4. Redistribution. You may reproduce and distribute copies of the
 | 
			
		||||
      Work or Derivative Works thereof in any medium, with or without
 | 
			
		||||
      modifications, and in Source or Object form, provided that You
 | 
			
		||||
      meet the following conditions:
 | 
			
		||||
 | 
			
		||||
      (a) You must give any other recipients of the Work or
 | 
			
		||||
          Derivative Works a copy of this License; and
 | 
			
		||||
 | 
			
		||||
      (b) You must cause any modified files to carry prominent notices
 | 
			
		||||
          stating that You changed the files; and
 | 
			
		||||
 | 
			
		||||
      (c) You must retain, in the Source form of any Derivative Works
 | 
			
		||||
          that You distribute, all copyright, patent, trademark, and
 | 
			
		||||
          attribution notices from the Source form of the Work,
 | 
			
		||||
          excluding those notices that do not pertain to any part of
 | 
			
		||||
          the Derivative Works; and
 | 
			
		||||
 | 
			
		||||
      (d) If the Work includes a "NOTICE" text file as part of its
 | 
			
		||||
          distribution, then any Derivative Works that You distribute must
 | 
			
		||||
          include a readable copy of the attribution notices contained
 | 
			
		||||
          within such NOTICE file, excluding those notices that do not
 | 
			
		||||
          pertain to any part of the Derivative Works, in at least one
 | 
			
		||||
          of the following places: within a NOTICE text file distributed
 | 
			
		||||
          as part of the Derivative Works; within the Source form or
 | 
			
		||||
          documentation, if provided along with the Derivative Works; or,
 | 
			
		||||
          within a display generated by the Derivative Works, if and
 | 
			
		||||
          wherever such third-party notices normally appear. The contents
 | 
			
		||||
          of the NOTICE file are for informational purposes only and
 | 
			
		||||
          do not modify the License. You may add Your own attribution
 | 
			
		||||
          notices within Derivative Works that You distribute, alongside
 | 
			
		||||
          or as an addendum to the NOTICE text from the Work, provided
 | 
			
		||||
          that such additional attribution notices cannot be construed
 | 
			
		||||
          as modifying the License.
 | 
			
		||||
 | 
			
		||||
      You may add Your own copyright statement to Your modifications and
 | 
			
		||||
      may provide additional or different license terms and conditions
 | 
			
		||||
      for use, reproduction, or distribution of Your modifications, or
 | 
			
		||||
      for any such Derivative Works as a whole, provided Your use,
 | 
			
		||||
      reproduction, and distribution of the Work otherwise complies with
 | 
			
		||||
      the conditions stated in this License.
 | 
			
		||||
 | 
			
		||||
   5. Submission of Contributions. Unless You explicitly state otherwise,
 | 
			
		||||
      any Contribution intentionally submitted for inclusion in the Work
 | 
			
		||||
      by You to the Licensor shall be under the terms and conditions of
 | 
			
		||||
      this License, without any additional terms or conditions.
 | 
			
		||||
      Notwithstanding the above, nothing herein shall supersede or modify
 | 
			
		||||
      the terms of any separate license agreement you may have executed
 | 
			
		||||
      with Licensor regarding such Contributions.
 | 
			
		||||
 | 
			
		||||
   6. Trademarks. This License does not grant permission to use the trade
 | 
			
		||||
      names, trademarks, service marks, or product names of the Licensor,
 | 
			
		||||
      except as required for reasonable and customary use in describing the
 | 
			
		||||
      origin of the Work and reproducing the content of the NOTICE file.
 | 
			
		||||
 | 
			
		||||
   7. Disclaimer of Warranty. Unless required by applicable law or
 | 
			
		||||
      agreed to in writing, Licensor provides the Work (and each
 | 
			
		||||
      Contributor provides its Contributions) on an "AS IS" BASIS,
 | 
			
		||||
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 | 
			
		||||
      implied, including, without limitation, any warranties or conditions
 | 
			
		||||
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
 | 
			
		||||
      PARTICULAR PURPOSE. You are solely responsible for determining the
 | 
			
		||||
      appropriateness of using or redistributing the Work and assume any
 | 
			
		||||
      risks associated with Your exercise of permissions under this License.
 | 
			
		||||
 | 
			
		||||
   8. Limitation of Liability. In no event and under no legal theory,
 | 
			
		||||
      whether in tort (including negligence), contract, or otherwise,
 | 
			
		||||
      unless required by applicable law (such as deliberate and grossly
 | 
			
		||||
      negligent acts) or agreed to in writing, shall any Contributor be
 | 
			
		||||
      liable to You for damages, including any direct, indirect, special,
 | 
			
		||||
      incidental, or consequential damages of any character arising as a
 | 
			
		||||
      result of this License or out of the use or inability to use the
 | 
			
		||||
      Work (including but not limited to damages for loss of goodwill,
 | 
			
		||||
      work stoppage, computer failure or malfunction, or any and all
 | 
			
		||||
      other commercial damages or losses), even if such Contributor
 | 
			
		||||
      has been advised of the possibility of such damages.
 | 
			
		||||
 | 
			
		||||
   9. Accepting Warranty or Additional Liability. While redistributing
 | 
			
		||||
      the Work or Derivative Works thereof, You may choose to offer,
 | 
			
		||||
      and charge a fee for, acceptance of support, warranty, indemnity,
 | 
			
		||||
      or other liability obligations and/or rights consistent with this
 | 
			
		||||
      License. However, in accepting such obligations, You may act only
 | 
			
		||||
      on Your own behalf and on Your sole responsibility, not on behalf
 | 
			
		||||
      of any other Contributor, and only if You agree to indemnify,
 | 
			
		||||
      defend, and hold each Contributor harmless for any liability
 | 
			
		||||
      incurred by, or claims asserted against, such Contributor by reason
 | 
			
		||||
      of your accepting any such warranty or additional liability.
 | 
			
		||||
 | 
			
		||||
   END OF TERMS AND CONDITIONS
 | 
			
		||||
 | 
			
		||||
   APPENDIX: How to apply the Apache License to your work.
 | 
			
		||||
 | 
			
		||||
      To apply the Apache License to your work, attach the following
 | 
			
		||||
      boilerplate notice, with the fields enclosed by brackets "[]"
 | 
			
		||||
      replaced with your own identifying information. (Don't include
 | 
			
		||||
      the brackets!)  The text should be enclosed in the appropriate
 | 
			
		||||
      comment syntax for the file format. We also recommend that a
 | 
			
		||||
      file or class name and description of purpose be included on the
 | 
			
		||||
      same "printed page" as the copyright notice for easier
 | 
			
		||||
      identification within third-party archives.
 | 
			
		||||
 | 
			
		||||
   Copyright [yyyy] [name of copyright owner]
 | 
			
		||||
 | 
			
		||||
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
   you may not use this file except in compliance with the License.
 | 
			
		||||
   You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
       http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
   Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
   See the License for the specific language governing permissions and
 | 
			
		||||
   limitations under the License.
 | 
			
		||||
							
								
								
									
										23
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								README.md
									
									
									
									
									
								
							@ -1,5 +1,16 @@
 | 
			
		||||
# kernels
 | 
			
		||||
 | 
			
		||||
<div align="center">
 | 
			
		||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
 | 
			
		||||
<p align="center">
 | 
			
		||||
    <a href="https://pypi.org/project/kernels"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/kernels"></a>
 | 
			
		||||
    <a href="https://github.com/huggingface/kernels/tags"><img alt="GitHub tag" src="https://img.shields.io/github/v/tag/huggingface/kernels"></a>
 | 
			
		||||
    <a href="https://github.com/huggingface/kernels/actions/workflows/docker-build-push.yaml"><img alt="Test kernels" src="https://img.shields.io/github/actions/workflow/status/huggingface/kernels/test.yml?label=test"></a>
 | 
			
		||||
  
 | 
			
		||||
</p>
 | 
			
		||||
</div>
 | 
			
		||||
<hr/>
 | 
			
		||||
 | 
			
		||||
The Kernel Hub allows Python libraries and applications to load compute
 | 
			
		||||
kernels directly from the [Hub](https://hf.co/). To support this kind
 | 
			
		||||
of dynamic loading, Hub kernels differ from traditional Python kernel
 | 
			
		||||
@ -45,8 +56,12 @@ the Hub.
 | 
			
		||||
 | 
			
		||||
## 📚 Documentation
 | 
			
		||||
 | 
			
		||||
- [Using layers](docs/layers.md)
 | 
			
		||||
- [Locking kernel versions](docs/locking.md)
 | 
			
		||||
- [Using kernels in a Docker container](docs/docker.md)
 | 
			
		||||
- [Kernel requirements](docs/kernel-requirements.md)
 | 
			
		||||
- [Introduction](docs/source/index.md)
 | 
			
		||||
- [Installation](docs/source/installation.md)
 | 
			
		||||
- [Basic usage](docs/source/basic-usage.md)
 | 
			
		||||
- [Using layers](docs/source/layers.md)
 | 
			
		||||
- [Locking kernel/layer versions](docs/source/locking.md)
 | 
			
		||||
- [Environment variables](docs/source/env.md)
 | 
			
		||||
- [Kernel requirements](docs/source/kernel-requirements.md)
 | 
			
		||||
- [Frequently Asked Questions](docs/source/faq.md)
 | 
			
		||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
 | 
			
		||||
 | 
			
		||||
@ -1,8 +0,0 @@
 | 
			
		||||
# Using kernels in a Docker container
 | 
			
		||||
 | 
			
		||||
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
 | 
			
		||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
 | 
			
		||||
```
 | 
			
		||||
@ -1,79 +0,0 @@
 | 
			
		||||
# Layers
 | 
			
		||||
 | 
			
		||||
A kernel can provide layers in addition to kernel functions. A layer from
 | 
			
		||||
the Hub can replace the `forward` method of an existing layer for a certain
 | 
			
		||||
device type. This makes it possible to provide more performant kernels for
 | 
			
		||||
existing layers.
 | 
			
		||||
 | 
			
		||||
See [Kernel requirements](kernel-requirements.md) for more information the
 | 
			
		||||
requirements of Hub layers.
 | 
			
		||||
 | 
			
		||||
## Making a layer extensible with kernels from the hub
 | 
			
		||||
 | 
			
		||||
### Using a decorator
 | 
			
		||||
 | 
			
		||||
A layer can be made extensible with the `use_kernel_forward_from_hub`
 | 
			
		||||
decorator. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
@use_kernel_forward_from_hub("SiluAndMul")
 | 
			
		||||
class SiluAndMul(nn.Module):
 | 
			
		||||
    def forward(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        d = input.shape[-1] // 2
 | 
			
		||||
        return F.silu(input[..., :d]) * input[..., d:]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The decorator changes the layer, so that other implementations of the `forward`
 | 
			
		||||
method can be registered using the name `SiluAndMul`.
 | 
			
		||||
 | 
			
		||||
### External layers
 | 
			
		||||
 | 
			
		||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
 | 
			
		||||
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from somelibrary import SiluAndMul
 | 
			
		||||
 | 
			
		||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
 | 
			
		||||
register_kernel_mapping(kernel_layer_mapping)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
 | 
			
		||||
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
 | 
			
		||||
section for more information.
 | 
			
		||||
 | 
			
		||||
**Warning:** we strongly recommend using layers with a decorator, since
 | 
			
		||||
it signifies that the maintainer intends to keep the `forward` signature
 | 
			
		||||
compatible with layers from the hub.
 | 
			
		||||
 | 
			
		||||
## Registering a hub kernel for a layer
 | 
			
		||||
 | 
			
		||||
Once a layer is made extensible, users can register hub kernels for it
 | 
			
		||||
by name using the `register_kernel_mapping` function. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
            revision="layers",
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
register_kernel_mapping(kernel_layer_mapping)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will register the kernel mapping in the current context, which is
 | 
			
		||||
normally global. It is recommended to scope the mapping to where it is
 | 
			
		||||
used with the `use_kernel_mapping` context manager:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
with use_kernel_mapping(kernel_layer_mapping):
 | 
			
		||||
    # Use the layer for which the mapping is applied.
 | 
			
		||||
    ...
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This ensures that the mapping is not active anymore outside the
 | 
			
		||||
`with`-scope.
 | 
			
		||||
							
								
								
									
										30
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
			
		||||
- sections:
 | 
			
		||||
    - local: index
 | 
			
		||||
      title: Introduction
 | 
			
		||||
    - local: installation
 | 
			
		||||
      title: Installation
 | 
			
		||||
  title: Getting started
 | 
			
		||||
- sections:
 | 
			
		||||
    - local: basic-usage
 | 
			
		||||
      title: Basic Usage
 | 
			
		||||
    - local: layers
 | 
			
		||||
      title: Using Layers
 | 
			
		||||
    - local: locking
 | 
			
		||||
      title: Locking Kernel Versions
 | 
			
		||||
    - local: env
 | 
			
		||||
      title: Environment Variables
 | 
			
		||||
    - local: faq
 | 
			
		||||
      title: FAQ
 | 
			
		||||
  title: Usage Guide
 | 
			
		||||
- sections:
 | 
			
		||||
    - local: api/kernels
 | 
			
		||||
      title: Kernels
 | 
			
		||||
    - local: api/layers
 | 
			
		||||
      title: Layers
 | 
			
		||||
    - local: cli
 | 
			
		||||
      title: Kernels CLI
 | 
			
		||||
  title: API Reference
 | 
			
		||||
- sections:
 | 
			
		||||
    - local: kernel-requirements
 | 
			
		||||
      title: Kernel Requirements
 | 
			
		||||
  title: Developer Guide
 | 
			
		||||
							
								
								
									
										21
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
			
		||||
# Kernels API Reference
 | 
			
		||||
 | 
			
		||||
## Main Functions
 | 
			
		||||
 | 
			
		||||
### get_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.get_kernel
 | 
			
		||||
 | 
			
		||||
### has_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.has_kernel
 | 
			
		||||
 | 
			
		||||
## Loading locked kernels
 | 
			
		||||
 | 
			
		||||
### load_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.load_kernel
 | 
			
		||||
 | 
			
		||||
### get_locked_kernel
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.get_locked_kernel
 | 
			
		||||
							
								
								
									
										41
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
# Layers API Reference
 | 
			
		||||
 | 
			
		||||
## Making layers kernel-aware
 | 
			
		||||
 | 
			
		||||
### use_kernel_forward_from_hub
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.use_kernel_forward_from_hub
 | 
			
		||||
 | 
			
		||||
### replace_kernel_forward_from_hub
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.replace_kernel_forward_from_hub
 | 
			
		||||
 | 
			
		||||
## Registering kernel mappings
 | 
			
		||||
 | 
			
		||||
### use_kernel_mapping
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.use_kernel_mapping
 | 
			
		||||
 | 
			
		||||
### register_kernel_mapping
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.register_kernel_mapping
 | 
			
		||||
 | 
			
		||||
## Kernelizing a model
 | 
			
		||||
 | 
			
		||||
### kernelize
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.kernelize
 | 
			
		||||
 | 
			
		||||
## Classes
 | 
			
		||||
 | 
			
		||||
### Device
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.Device
 | 
			
		||||
 | 
			
		||||
### Mode
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.Mode
 | 
			
		||||
 | 
			
		||||
### LayerRepository
 | 
			
		||||
 | 
			
		||||
[[autodoc]] kernels.LayerRepository
 | 
			
		||||
							
								
								
									
										50
									
								
								docs/source/basic-usage.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								docs/source/basic-usage.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,50 @@
 | 
			
		||||
# Basic Usage
 | 
			
		||||
 | 
			
		||||
## Loading Kernels
 | 
			
		||||
 | 
			
		||||
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
import torch
 | 
			
		||||
from kernels import get_kernel
 | 
			
		||||
 | 
			
		||||
# Download optimized kernels from the Hugging Face hub
 | 
			
		||||
activation = get_kernel("kernels-community/activation")
 | 
			
		||||
 | 
			
		||||
# Create a random tensor
 | 
			
		||||
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
 | 
			
		||||
 | 
			
		||||
# Run the kernel
 | 
			
		||||
y = torch.empty_like(x)
 | 
			
		||||
activation.gelu_fast(y, x)
 | 
			
		||||
 | 
			
		||||
print(y)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Using version bounds
 | 
			
		||||
 | 
			
		||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
 | 
			
		||||
You can specify which version to download using Python version specifiers:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
import torch
 | 
			
		||||
from kernels import get_kernel
 | 
			
		||||
 | 
			
		||||
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
 | 
			
		||||
is strongly recommended to specify a version bound, since a kernel author
 | 
			
		||||
might push incompatible changes to the `main` branch.
 | 
			
		||||
 | 
			
		||||
## Checking Kernel Availability
 | 
			
		||||
 | 
			
		||||
You can check if a specific kernel is available for your environment:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from kernels import has_kernel
 | 
			
		||||
 | 
			
		||||
# Check if kernel is available for current environment
 | 
			
		||||
is_available = has_kernel("kernels-community/activation")
 | 
			
		||||
print(f"Kernel available: {is_available}")
 | 
			
		||||
```
 | 
			
		||||
							
								
								
									
										41
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
# Kernels CLI Reference
 | 
			
		||||
 | 
			
		||||
## Main Functions
 | 
			
		||||
 | 
			
		||||
### kernels to-wheel
 | 
			
		||||
 | 
			
		||||
We strongly recommend downloading kernels from the Hub using the `kernels`
 | 
			
		||||
package, since this comes with large [benefits](index.md) over using Python
 | 
			
		||||
wheels. That said, some projects may require deployment of kernels as
 | 
			
		||||
wheels. The `kernels` utility provides a simple solution to this. You can
 | 
			
		||||
convert any Hub kernel into a set of wheels with the `to-wheel` command:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
$ kernels to-wheel drbh/img2grey 1.1.2
 | 
			
		||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### kernels upload
 | 
			
		||||
 | 
			
		||||
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
 | 
			
		||||
your kernel builds to the Hub.
 | 
			
		||||
 | 
			
		||||
**Notes**:
 | 
			
		||||
 | 
			
		||||
- This will take care of creating a repository on the Hub with the `repo_id` provided.
 | 
			
		||||
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
 | 
			
		||||
  being uploaded, it will attempt to delete the files existing under it.
 | 
			
		||||
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										10
									
								
								docs/source/env.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								docs/source/env.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,10 @@
 | 
			
		||||
# Environment variables
 | 
			
		||||
 | 
			
		||||
## `KERNELS_CACHE`
 | 
			
		||||
 | 
			
		||||
The directory to use as the local kernel cache. If not set, the cache
 | 
			
		||||
of the `huggingface_hub` package is used.
 | 
			
		||||
 | 
			
		||||
## `DISABLE_KERNEL_MAPPING`
 | 
			
		||||
 | 
			
		||||
Disables kernel mappings for [`layers`](layers.md).
 | 
			
		||||
							
								
								
									
										41
									
								
								docs/source/faq.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								docs/source/faq.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
# FAQ
 | 
			
		||||
 | 
			
		||||
## Kernel layers
 | 
			
		||||
 | 
			
		||||
### Why is the kernelization step needed as a separate step?
 | 
			
		||||
 | 
			
		||||
In earlier versions of `kernels`, a layer's `forward` method was replaced
 | 
			
		||||
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
 | 
			
		||||
The new `forward` would dispatch to a kernel based on the device type,
 | 
			
		||||
whether a model was training, etc. However, this approach was
 | 
			
		||||
fundamentally incompatible with `torch.compile` since it relied
 | 
			
		||||
on data-dependent branching.
 | 
			
		||||
 | 
			
		||||
To avoid branching, we have to make dispatch decisions ahead of time,
 | 
			
		||||
which is what the `kernelize` function does.
 | 
			
		||||
 | 
			
		||||
### Why does kernelization only replace `forward` methods?
 | 
			
		||||
 | 
			
		||||
There are some other possible approaches. The first is to completely
 | 
			
		||||
replace existing layers by kernel layers. However, since this would
 | 
			
		||||
permit free-form layer classes, it would be much harder to validate
 | 
			
		||||
that layers are fully compatible with the layers that they are
 | 
			
		||||
replacing. For instance, they could have completely different member
 | 
			
		||||
variables. Besides that, we would also need to hold on to the original
 | 
			
		||||
layers, in case we need to revert to the base layers when the model
 | 
			
		||||
is `kernelize`d again with different options.
 | 
			
		||||
 | 
			
		||||
A second approach would be to make an auxiliary layer that wraps the
 | 
			
		||||
original layer and the kernel layer and dispatches to the kernel layer.
 | 
			
		||||
This wouldn't have the issues of the first approach, because kernel layers
 | 
			
		||||
could be similarly strict as they are now, and we would still have access
 | 
			
		||||
to the original layers when `kernelize`-ing the model again. However,
 | 
			
		||||
this would change the graph structure of the model and would break use
 | 
			
		||||
cases where programs access the model internals (e.g.
 | 
			
		||||
`model.layers[0].attention.query_weight`) or rely on the graph structure
 | 
			
		||||
in other ways.
 | 
			
		||||
 | 
			
		||||
The approach of `forward`-replacement is the least invasive, because
 | 
			
		||||
it preserves the original model graph. It is also reversible, since
 | 
			
		||||
even though the `forward` of a layer _instance_ might be replaced,
 | 
			
		||||
the corresponding class still has the original `forward`.
 | 
			
		||||
							
								
								
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,20 @@
 | 
			
		||||
# Kernels
 | 
			
		||||
 | 
			
		||||
<div align="center">
 | 
			
		||||
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
 | 
			
		||||
</div>
 | 
			
		||||
 | 
			
		||||
The Kernel Hub allows Python libraries and applications to load compute
 | 
			
		||||
kernels directly from the [Hub](https://hf.co/). To support this kind
 | 
			
		||||
of dynamic loading, Hub kernels differ from traditional Python kernel
 | 
			
		||||
packages in that they are made to be:
 | 
			
		||||
 | 
			
		||||
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
 | 
			
		||||
- **Unique**: multiple versions of the same kernel can be loaded in the
 | 
			
		||||
  same Python process.
 | 
			
		||||
- **Compatible**: kernels must support all recent versions of Python and
 | 
			
		||||
  the different PyTorch build configurations (various CUDA versions
 | 
			
		||||
  and C++ ABIs). Furthermore, older C library versions must be supported.
 | 
			
		||||
 | 
			
		||||
You can [search for kernels](https://huggingface.co/models?other=kernel) on
 | 
			
		||||
the Hub.
 | 
			
		||||
							
								
								
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
			
		||||
# Installation
 | 
			
		||||
 | 
			
		||||
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
pip install kernels
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
# Using kernels in a Docker container
 | 
			
		||||
 | 
			
		||||
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
 | 
			
		||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
 | 
			
		||||
```
 | 
			
		||||
@ -1,8 +1,11 @@
 | 
			
		||||
# Kernel requirements
 | 
			
		||||
 | 
			
		||||
Kernels on the Hub must fulfill the requirements outlined on this page.
 | 
			
		||||
Kernels on the Hub must fulfill the requirements outlined on this page. By
 | 
			
		||||
ensuring kernels are compliant, they can be used on a wide range of Linux
 | 
			
		||||
systems and Torch builds.
 | 
			
		||||
 | 
			
		||||
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
 | 
			
		||||
to build conforming kernels.
 | 
			
		||||
to build compliant kernels.
 | 
			
		||||
 | 
			
		||||
## Directory layout
 | 
			
		||||
 | 
			
		||||
@ -10,52 +13,73 @@ A kernel repository on the Hub must contain a `build` directory. This
 | 
			
		||||
directory contains build variants of a kernel in the form of directories
 | 
			
		||||
following the template
 | 
			
		||||
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
 | 
			
		||||
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
 | 
			
		||||
recommended build variants are:
 | 
			
		||||
For example `build/torch26-cxx98-cu118-x86_64-linux`.
 | 
			
		||||
 | 
			
		||||
- `torch25-cxx11-cu118-x86_64-linux`
 | 
			
		||||
- `torch25-cxx11-cu121-x86_64-linux`
 | 
			
		||||
- `torch25-cxx11-cu124-x86_64-linux`
 | 
			
		||||
- `torch25-cxx98-cu118-x86_64-linux`
 | 
			
		||||
- `torch25-cxx98-cu121-x86_64-linux`
 | 
			
		||||
- `torch25-cxx98-cu124-x86_64-linux`
 | 
			
		||||
- `torch26-cxx11-cu118-x86_64-linux`
 | 
			
		||||
- `torch26-cxx11-cu124-x86_64-linux`
 | 
			
		||||
- `torch26-cxx11-cu126-x86_64-linux`
 | 
			
		||||
- `torch26-cxx98-cu118-x86_64-linux`
 | 
			
		||||
- `torch26-cxx98-cu124-x86_64-linux`
 | 
			
		||||
- `torch26-cxx98-cu126-x86_64-linux`
 | 
			
		||||
 | 
			
		||||
This list will be updated as new PyTorch versions are released. Kernels
 | 
			
		||||
that are in pure Python (e.g. Triton kernels) only need to provide a
 | 
			
		||||
single build variant:
 | 
			
		||||
 | 
			
		||||
- `torch-universal`
 | 
			
		||||
 | 
			
		||||
Each variant directory should contain a single directory with the same name
 | 
			
		||||
Each variant directory must contain a single directory with the same name
 | 
			
		||||
as the repository (replacing `-` by `_`). For instance, kernels in the
 | 
			
		||||
`kernels-community/activation` repository have a directories like
 | 
			
		||||
`build/<variant>/activation`. This directory
 | 
			
		||||
must be a Python package with an `__init__.py` file.
 | 
			
		||||
 | 
			
		||||
## Build variants
 | 
			
		||||
 | 
			
		||||
A kernel can be compliant for a specific compute framework (e.g. CUDA) or
 | 
			
		||||
architecture (e.g. x86_64). For compliance with a compute framework and
 | 
			
		||||
architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md)
 | 
			
		||||
must be available for that combination.
 | 
			
		||||
 | 
			
		||||
## Versioning
 | 
			
		||||
 | 
			
		||||
Kernels are versioned on the Hub using Git tags. Version tags must be of
 | 
			
		||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
 | 
			
		||||
to resolve the version constraints.
 | 
			
		||||
 | 
			
		||||
We recommend using [semver](https://semver.org/) to version kernels.
 | 
			
		||||
 | 
			
		||||
## Native Python module
 | 
			
		||||
 | 
			
		||||
Kernels will typically contain a native Python module with precompiled
 | 
			
		||||
compute kernels and bindings. This module must fulfill the following
 | 
			
		||||
requirements:
 | 
			
		||||
compute kernels and bindings. This module must fulfill the requirements
 | 
			
		||||
outlined in this section. For all operating systems, a kernel must not
 | 
			
		||||
have dynamic library dependencies outside:
 | 
			
		||||
 | 
			
		||||
- Torch;
 | 
			
		||||
- CUDA/ROCm libraries installed as dependencies of Torch.
 | 
			
		||||
 | 
			
		||||
### Linux
 | 
			
		||||
 | 
			
		||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
 | 
			
		||||
  for compatibility with Python 3.9 and later.
 | 
			
		||||
- Compatible with glibc 2.27 or later. This means that no symbols
 | 
			
		||||
  from later versions must be used. To archive this, the module should
 | 
			
		||||
  be built against this glibc version. **Warning:** libgcc must also be
 | 
			
		||||
  built against glibc 2.27 to avoid leaking symbols.
 | 
			
		||||
- No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
 | 
			
		||||
  must be static.
 | 
			
		||||
- No dynamic library dependencies outside Torch or CUDA libraries
 | 
			
		||||
  installed as dependencies of Torch.
 | 
			
		||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
 | 
			
		||||
  This means that the extension **must not** use symbols versions higher than:
 | 
			
		||||
  - GLIBC 2.28
 | 
			
		||||
  - GLIBCXX 3.4.24
 | 
			
		||||
  - CXXABI 1.3.11
 | 
			
		||||
  - GCC 7.0.0
 | 
			
		||||
 | 
			
		||||
(These requirements will be updated as new PyTorch versions are released.)
 | 
			
		||||
These requirements can be checked with the ABI checker (see below).
 | 
			
		||||
 | 
			
		||||
### macOS
 | 
			
		||||
 | 
			
		||||
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
 | 
			
		||||
  for compatibility with Python 3.9 and later.
 | 
			
		||||
- macOS deployment target 15.0.
 | 
			
		||||
- Metal 3.0 (`-std=metal3.0`).
 | 
			
		||||
 | 
			
		||||
The ABI3 requirement can be checked with the ABI checker (see below).
 | 
			
		||||
 | 
			
		||||
### ABI checker
 | 
			
		||||
 | 
			
		||||
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
 | 
			
		||||
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
 | 
			
		||||
$ cargo install kernel-abi-check
 | 
			
		||||
$ kernel-abi-check result/relu/_relu_e87e0ca_dirty.abi3.so
 | 
			
		||||
🐍 Checking for compatibility with manylinux_2_28 and Python ABI version 3.9
 | 
			
		||||
✅ No compatibility issues found
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Torch extension
 | 
			
		||||
 | 
			
		||||
@ -98,10 +122,20 @@ requirements:
 | 
			
		||||
- The `forward` method has a signature that is compatible with the
 | 
			
		||||
  `forward` method that it is extending.
 | 
			
		||||
 | 
			
		||||
There are two exceptions to the _no class variables rule_:
 | 
			
		||||
 | 
			
		||||
1. The `has_backward` variable can be used to indicate whether the layer has
 | 
			
		||||
   a backward pass implemented (`True` when absent).
 | 
			
		||||
2. The `can_torch_compile` variable can be used to indicate whether the layer
 | 
			
		||||
   supports `torch.compile` (`False` when absent).
 | 
			
		||||
 | 
			
		||||
This is an example of a pure layer:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
class SiluAndMul(nn.Module):
 | 
			
		||||
    # This layer does not implement backward.
 | 
			
		||||
    has_backward: bool = False
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        d = x.shape[-1] // 2
 | 
			
		||||
        output_shape = x.shape[:-1] + (d,)
 | 
			
		||||
							
								
								
									
										323
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										323
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,323 @@
 | 
			
		||||
# Layers
 | 
			
		||||
 | 
			
		||||
A kernel can provide layers in addition to kernel functions. A layer from
 | 
			
		||||
the Hub can replace the `forward` method of an existing layer for a certain
 | 
			
		||||
device type. This makes it possible to provide more performant kernels for
 | 
			
		||||
existing layers.
 | 
			
		||||
 | 
			
		||||
See [Kernel requirements](kernel-requirements.md) for more information on the
 | 
			
		||||
requirements of Hub layers.
 | 
			
		||||
 | 
			
		||||
## Making a layer extensible with kernels from the hub
 | 
			
		||||
 | 
			
		||||
### Using a decorator
 | 
			
		||||
 | 
			
		||||
A layer can be made extensible with the `use_kernel_forward_from_hub`
 | 
			
		||||
decorator. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
@use_kernel_forward_from_hub("SiluAndMul")
 | 
			
		||||
class SiluAndMul(nn.Module):
 | 
			
		||||
    def forward(self, input: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        d = input.shape[-1] // 2
 | 
			
		||||
        return F.silu(input[..., :d]) * input[..., d:]
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The decorator does not change the behavior of the class -- it annotates
 | 
			
		||||
the class with the given name (here `SiluAndMul`). The `kernelize` function
 | 
			
		||||
described below uses this name to look up kernels for the layer.
 | 
			
		||||
 | 
			
		||||
### External layers
 | 
			
		||||
 | 
			
		||||
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
 | 
			
		||||
decorator can be made extensible using the `replace_kernel_forward_from_hub`
 | 
			
		||||
function:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
from somelibrary import SiluAndMul
 | 
			
		||||
 | 
			
		||||
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
**Warning:** we strongly recommend using layers with a decorator, since
 | 
			
		||||
it signifies that the maintainer intends to keep the `forward` signature
 | 
			
		||||
compatible with layers from the hub.
 | 
			
		||||
 | 
			
		||||
## Kernelizing a model
 | 
			
		||||
 | 
			
		||||
A model will not use Hub kernels by default, even if it contains extensible
 | 
			
		||||
layers. To enable the use of Hub kernels in the model, it needs to be
 | 
			
		||||
'kernelized' using the `kernelize` function. This function traverses the
 | 
			
		||||
model graph and replaces the `forward` methods of extensible layers for which
 | 
			
		||||
Hub kernels are registered. `kernelize` can be used as follows:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
model = MyModel(...)
 | 
			
		||||
model = kernelize(model, mode=Mode.INFERENCE)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `kernelize` function modifies the model in-place, the model itself is
 | 
			
		||||
returned as a convenience. The `mode` specifies that the model will be used
 | 
			
		||||
in inference. Similarly, you can ask `kernelize` to prepare the model for
 | 
			
		||||
training:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
model = MyModel(...)
 | 
			
		||||
model = kernelize(model, mode=Mode.TRAINING)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
A model that is kernelized for training can also be used for inference, but
 | 
			
		||||
not the other way around. If you want to change the mode of the kernelized
 | 
			
		||||
model, you can just run `kernelize` on the model again with the new mode.
 | 
			
		||||
 | 
			
		||||
If you want to compile a model with `torch.compile`, this should be indicated
 | 
			
		||||
in the mode as well. You can do this by combining `Mode.INFERENCE` or
 | 
			
		||||
`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
model = MyModel(...)
 | 
			
		||||
 | 
			
		||||
# Inference
 | 
			
		||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
 | 
			
		||||
 | 
			
		||||
# Training
 | 
			
		||||
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Kernel device
 | 
			
		||||
 | 
			
		||||
Kernels can be registered per device type. For instance, separate `cuda` and
 | 
			
		||||
`metal` kernels could be registered for the name `SiluAndMul`. By default,
 | 
			
		||||
`kernelize` will try to infer the device type from the model's parameters.
 | 
			
		||||
You can pass the device type to `kernelize` if the device type cannot be
 | 
			
		||||
inferred (e.g. because the model has no parameters):
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
model = MyModel(...)
 | 
			
		||||
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### Fallback `forward`
 | 
			
		||||
 | 
			
		||||
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
 | 
			
		||||
kernel does not support backward passes or `torch.compile` respectively,
 | 
			
		||||
`kernelize` will fall back to the original, non-kernelized, layer. You
 | 
			
		||||
can let `kernelize` raise an exception instead by using `use_fallback=False`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
model = MyModel(...)
 | 
			
		||||
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This can be useful if you want to guarantee that Hub kernels are used.
 | 
			
		||||
 | 
			
		||||
### Inspecting which kernels are used
 | 
			
		||||
 | 
			
		||||
The kernels that are used are logged at the `INFO` level by `kernelize`.
 | 
			
		||||
See the [Python logging](https://docs.python.org/3/library/logging.html)
 | 
			
		||||
documentation for information on how to configure logging.
 | 
			
		||||
 | 
			
		||||
## Registering a hub kernel for a layer
 | 
			
		||||
 | 
			
		||||
`kernelize` relies on kernel mappings to find Hub kernels for layers.
 | 
			
		||||
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
 | 
			
		||||
the Hub. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        ),
 | 
			
		||||
        "rocm": LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
You can register such a mapping using `register_kernel_mapping`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
register_kernel_mapping(kernel_layer_mapping)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will register the kernel mapping in the current context, which is
 | 
			
		||||
normally global. It is recommended to scope the mapping to where it is
 | 
			
		||||
used with the `use_kernel_mapping` context manager:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
with use_kernel_mapping(kernel_layer_mapping):
 | 
			
		||||
    # Use the layer for which the mapping is applied.
 | 
			
		||||
    model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This ensures that the mapping is not active anymore outside the
 | 
			
		||||
`with`-scope.
 | 
			
		||||
 | 
			
		||||
### Using version bounds
 | 
			
		||||
 | 
			
		||||
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
 | 
			
		||||
You can specify which version of the kernel to download using Python version
 | 
			
		||||
specifiers:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
            version=">=0.0.4,<0.1.0",
 | 
			
		||||
        ),
 | 
			
		||||
        "rocm": LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
            version=">=0.0.4,<0.1.0",
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
 | 
			
		||||
least 4. It is strongly recommended to specify a version bound, since a
 | 
			
		||||
kernel author might push incompatible changes to the `main` branch.
 | 
			
		||||
 | 
			
		||||
### Registering kernels for specific modes
 | 
			
		||||
 | 
			
		||||
You might want to register two different kernels for a particular layer,
 | 
			
		||||
where one kernel is optimized for a specific mode. You can do so by
 | 
			
		||||
registering layer repositories for specific modes. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": {
 | 
			
		||||
          Mode.INFERENCE: LayerRepository(
 | 
			
		||||
              repo_id="kernels-community/activation-inference-optimized",
 | 
			
		||||
              layer_name="SiluAndMul",
 | 
			
		||||
          ),
 | 
			
		||||
          Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
 | 
			
		||||
              repo_id="kernels-community/activation-training-optimized",
 | 
			
		||||
              layer_name="SiluAndMul",
 | 
			
		||||
          ),
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
The `kernelize` function will attempt to use the following registered
 | 
			
		||||
kernels for a given mode:
 | 
			
		||||
 | 
			
		||||
- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
 | 
			
		||||
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
 | 
			
		||||
- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
 | 
			
		||||
  `TRAINING | TORCH_COMPILE` → `FALLBACK`
 | 
			
		||||
- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
 | 
			
		||||
- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
 | 
			
		||||
 | 
			
		||||
`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
 | 
			
		||||
is also used when a kernel is registered without a mode, as described in the
 | 
			
		||||
previous section.
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": {
 | 
			
		||||
            Mode.FALLBACK: LayerRepository(
 | 
			
		||||
                repo_id="kernels-community/activation",
 | 
			
		||||
                layer_name="SiluAndMul",
 | 
			
		||||
            ),
 | 
			
		||||
            Mode.INFERENCE: LayerRepository(
 | 
			
		||||
                repo_id="kernels-community/activation-inference-optimized",
 | 
			
		||||
                layer_name="SiluAndMul",
 | 
			
		||||
            ),
 | 
			
		||||
            Mode.TRAINING: LayerRepository(
 | 
			
		||||
                repo_id="kernels-community/activation-training-optimized",
 | 
			
		||||
                layer_name="SiluAndMul",
 | 
			
		||||
            ),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
 | 
			
		||||
`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
 | 
			
		||||
since the other kernels do not support `torch.compile`.
 | 
			
		||||
 | 
			
		||||
### Registering kernels for specific CUDA capabilities
 | 
			
		||||
 | 
			
		||||
Some kernels only work with newer CUDA architectures. For instance, some
 | 
			
		||||
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
 | 
			
		||||
supports registering layers for a range of CUDA capabilities. To do so,
 | 
			
		||||
you need to register the layer for a `Device` with type `cuda` and
 | 
			
		||||
set the supported range of CUDA capabilities with using `CUDAProperties`:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        Device(
 | 
			
		||||
            type="cuda",
 | 
			
		||||
            properties=CUDAProperties(
 | 
			
		||||
                min_capability=75, max_capability=89
 | 
			
		||||
            ),
 | 
			
		||||
        ): LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        ),
 | 
			
		||||
        Device(
 | 
			
		||||
            type="cuda",
 | 
			
		||||
            properties=CUDAProperties(
 | 
			
		||||
                min_capability=90, max_capability=sys.maxsize
 | 
			
		||||
            ),
 | 
			
		||||
        ): LayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation-hopper",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        ),
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Capabilities behave as follows:
 | 
			
		||||
 | 
			
		||||
- The minimum and maximum capabilities are inclusive.
 | 
			
		||||
- When a new kernel is registered with the same min/max capabilities as
 | 
			
		||||
  an existing kernel, the new kernel will replace the old kernel.
 | 
			
		||||
- When there are multiple kernels that support a capability, the kernel
 | 
			
		||||
  with the smaller capability interval will be used. E.g. given:
 | 
			
		||||
  - `KernelA` with `min_capability=80` and `max_capability=89`;
 | 
			
		||||
  - `KernelB` with `min_capability=75` and `max_capability=89`;
 | 
			
		||||
  - `kernelize` runs on a system with capability 8.6.
 | 
			
		||||
 | 
			
		||||
  Then `KernelA` will be used because the interval 80..89 is smaller
 | 
			
		||||
  than 75..89. The motivation is that kernels with smaller ranges
 | 
			
		||||
  tend to be more optimized for a specific set of GPUs. **This behavior
 | 
			
		||||
  might still change in the future.**
 | 
			
		||||
 | 
			
		||||
### Registering kernels for specific ROCm capabilities
 | 
			
		||||
 | 
			
		||||
Registering kernels for the ROCm architecture follows the exact same
 | 
			
		||||
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
 | 
			
		||||
a kernel to a range of ROCm capabilities.
 | 
			
		||||
 | 
			
		||||
### Loading from a local repository for testing
 | 
			
		||||
 | 
			
		||||
The `LocalLayerRepository` class is provided to load a repository from
 | 
			
		||||
a local directory. For example:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
with use_kernel_mapping(
 | 
			
		||||
    {
 | 
			
		||||
        "SiluAndMul": {
 | 
			
		||||
            "cuda": LocalLayerRepository(
 | 
			
		||||
                repo_path="/home/daniel/kernels/activation",
 | 
			
		||||
                package_name="activation",
 | 
			
		||||
                layer_name="SiluAndMul",
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
    },
 | 
			
		||||
    inherit_mapping=False,
 | 
			
		||||
):
 | 
			
		||||
    kernelize(linear, mode=Mode.INFERENCE)
 | 
			
		||||
```
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
# Locking kernel versions
 | 
			
		||||
# Locking kernel/layer versions
 | 
			
		||||
 | 
			
		||||
Projects that use `setuptools` can lock the kernel versions that should be
 | 
			
		||||
used. First specify the accepted versions in `pyproject.toml` and make
 | 
			
		||||
@ -13,7 +13,7 @@ build-backend = "setuptools.build_meta"
 | 
			
		||||
"kernels-community/activation" = ">=0.0.1"
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
 | 
			
		||||
Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with
 | 
			
		||||
the locked revisions. The locked revision will be used when loading a kernel with
 | 
			
		||||
`get_locked_kernel`:
 | 
			
		||||
 | 
			
		||||
@ -26,9 +26,27 @@ activation = get_locked_kernel("kernels-community/activation")
 | 
			
		||||
**Note:** the lock file is included in the package metadata, so it will only be visible
 | 
			
		||||
to `kernels` after doing an (editable or regular) installation of your project.
 | 
			
		||||
 | 
			
		||||
## Locked kernel layers
 | 
			
		||||
 | 
			
		||||
Locking is also supported for kernel layers. To use locked layers, register them
 | 
			
		||||
with the `LockedLayerRepository` class:
 | 
			
		||||
 | 
			
		||||
```python
 | 
			
		||||
kernel_layer_mapping = {
 | 
			
		||||
    "SiluAndMul": {
 | 
			
		||||
        "cuda": LockedLayerRepository(
 | 
			
		||||
            repo_id="kernels-community/activation",
 | 
			
		||||
            layer_name="SiluAndMul",
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
register_kernel_mapping(kernel_layer_mapping)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## Pre-downloading locked kernels
 | 
			
		||||
 | 
			
		||||
Locked kernels can be pre-downloaded by running `kernel download .` in your
 | 
			
		||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
 | 
			
		||||
project directory. This will download the kernels to your local Hugging Face
 | 
			
		||||
Hub cache.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										63
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										63
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							@ -51,30 +51,50 @@
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "nixpkgs": {
 | 
			
		||||
    "hf-nix": {
 | 
			
		||||
      "inputs": {
 | 
			
		||||
        "flake-compat": "flake-compat",
 | 
			
		||||
        "flake-utils": "flake-utils_2",
 | 
			
		||||
        "nixpkgs": "nixpkgs"
 | 
			
		||||
      },
 | 
			
		||||
      "locked": {
 | 
			
		||||
        "lastModified": 1737453259,
 | 
			
		||||
        "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
 | 
			
		||||
        "owner": "danieldk",
 | 
			
		||||
        "repo": "nixpkgs",
 | 
			
		||||
        "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
 | 
			
		||||
        "lastModified": 1754038838,
 | 
			
		||||
        "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
 | 
			
		||||
        "owner": "huggingface",
 | 
			
		||||
        "repo": "hf-nix",
 | 
			
		||||
        "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      },
 | 
			
		||||
      "original": {
 | 
			
		||||
        "owner": "danieldk",
 | 
			
		||||
        "ref": "outlines-v0.1.4-tgi",
 | 
			
		||||
        "owner": "huggingface",
 | 
			
		||||
        "repo": "hf-nix",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "nixpkgs": {
 | 
			
		||||
      "locked": {
 | 
			
		||||
        "lastModified": 1752785354,
 | 
			
		||||
        "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
 | 
			
		||||
        "owner": "nixos",
 | 
			
		||||
        "repo": "nixpkgs",
 | 
			
		||||
        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      },
 | 
			
		||||
      "original": {
 | 
			
		||||
        "owner": "nixos",
 | 
			
		||||
        "repo": "nixpkgs",
 | 
			
		||||
        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "root": {
 | 
			
		||||
      "inputs": {
 | 
			
		||||
        "flake-utils": "flake-utils",
 | 
			
		||||
        "hf-nix": "hf-nix",
 | 
			
		||||
        "nixpkgs": [
 | 
			
		||||
          "tgi-nix",
 | 
			
		||||
          "hf-nix",
 | 
			
		||||
          "nixpkgs"
 | 
			
		||||
        ],
 | 
			
		||||
        "tgi-nix": "tgi-nix"
 | 
			
		||||
        ]
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "systems": {
 | 
			
		||||
@ -106,27 +126,6 @@
 | 
			
		||||
        "repo": "default",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      }
 | 
			
		||||
    },
 | 
			
		||||
    "tgi-nix": {
 | 
			
		||||
      "inputs": {
 | 
			
		||||
        "flake-compat": "flake-compat",
 | 
			
		||||
        "flake-utils": "flake-utils_2",
 | 
			
		||||
        "nixpkgs": "nixpkgs"
 | 
			
		||||
      },
 | 
			
		||||
      "locked": {
 | 
			
		||||
        "lastModified": 1741617161,
 | 
			
		||||
        "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
 | 
			
		||||
        "owner": "huggingface",
 | 
			
		||||
        "repo": "text-generation-inference-nix",
 | 
			
		||||
        "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      },
 | 
			
		||||
      "original": {
 | 
			
		||||
        "owner": "huggingface",
 | 
			
		||||
        "ref": "kernels-0.2.0",
 | 
			
		||||
        "repo": "text-generation-inference-nix",
 | 
			
		||||
        "type": "github"
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  },
 | 
			
		||||
  "root": "root",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										20
									
								
								flake.nix
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								flake.nix
									
									
									
									
									
								
							@ -1,7 +1,7 @@
 | 
			
		||||
{
 | 
			
		||||
  inputs = {
 | 
			
		||||
    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
 | 
			
		||||
    nixpkgs.follows = "tgi-nix/nixpkgs";
 | 
			
		||||
    hf-nix.url = "github:huggingface/hf-nix";
 | 
			
		||||
    nixpkgs.follows = "hf-nix/nixpkgs";
 | 
			
		||||
    flake-utils.url = "github:numtide/flake-utils";
 | 
			
		||||
  };
 | 
			
		||||
  outputs =
 | 
			
		||||
@ -9,23 +9,27 @@
 | 
			
		||||
      self,
 | 
			
		||||
      nixpkgs,
 | 
			
		||||
      flake-utils,
 | 
			
		||||
      tgi-nix,
 | 
			
		||||
      hf-nix,
 | 
			
		||||
    }:
 | 
			
		||||
    flake-utils.lib.eachDefaultSystem (
 | 
			
		||||
      system:
 | 
			
		||||
      let
 | 
			
		||||
        pkgs = import nixpkgs {
 | 
			
		||||
          inherit system;
 | 
			
		||||
          inherit (tgi-nix.lib) config;
 | 
			
		||||
          config = hf-nix.lib.config system;
 | 
			
		||||
          overlays = [
 | 
			
		||||
            tgi-nix.overlays.default
 | 
			
		||||
            hf-nix.overlays.default
 | 
			
		||||
          ];
 | 
			
		||||
        };
 | 
			
		||||
      in
 | 
			
		||||
      {
 | 
			
		||||
        formatter = pkgs.nixfmt-rfc-style;
 | 
			
		||||
        formatter = pkgs.nixfmt-tree;
 | 
			
		||||
        devShells = with pkgs; rec {
 | 
			
		||||
          default = mkShell {
 | 
			
		||||
            nativeBuildInputs = [
 | 
			
		||||
              # For hf-doc-builder.
 | 
			
		||||
              nodejs
 | 
			
		||||
            ];
 | 
			
		||||
            buildInputs =
 | 
			
		||||
              [
 | 
			
		||||
                black
 | 
			
		||||
@ -34,10 +38,14 @@
 | 
			
		||||
                ruff
 | 
			
		||||
              ]
 | 
			
		||||
              ++ (with python3.pkgs; [
 | 
			
		||||
                docutils
 | 
			
		||||
                huggingface-hub
 | 
			
		||||
                mktestdocs
 | 
			
		||||
                pytest
 | 
			
		||||
                pytest-benchmark
 | 
			
		||||
                pyyaml
 | 
			
		||||
                torch
 | 
			
		||||
                types-pyyaml
 | 
			
		||||
                venvShellHook
 | 
			
		||||
              ]);
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
[project]
 | 
			
		||||
name = "kernels"
 | 
			
		||||
version = "0.3.1"
 | 
			
		||||
version = "0.10.2"
 | 
			
		||||
description = "Download compute kernels"
 | 
			
		||||
authors = [
 | 
			
		||||
  { name = "OlivierDehaene", email = "olivier@huggingface.co" },
 | 
			
		||||
@ -8,13 +8,14 @@ authors = [
 | 
			
		||||
  { name = "David Holtz", email = "david@huggingface.co" },
 | 
			
		||||
  { name = "Nicolas Patry", email = "nicolas@huggingface.co" },
 | 
			
		||||
]
 | 
			
		||||
license = { text = "Apache-2.0" }
 | 
			
		||||
readme = "README.md"
 | 
			
		||||
requires-python = ">= 3.9"
 | 
			
		||||
dependencies = [
 | 
			
		||||
  "huggingface-hub>=0.26.3",
 | 
			
		||||
  "packaging>=24.2",
 | 
			
		||||
  "tomli>=2.0.1; python_version<'3.11'",
 | 
			
		||||
  "torch>=2.5",
 | 
			
		||||
  "huggingface_hub>=0.26.0,<2.0",
 | 
			
		||||
  "packaging>=20.0",
 | 
			
		||||
  "pyyaml>=6",
 | 
			
		||||
  "tomli>=2.0; python_version<'3.11'",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[build-system]
 | 
			
		||||
@ -23,10 +24,19 @@ build-backend = "setuptools.build_meta"
 | 
			
		||||
 | 
			
		||||
[dependency-groups]
 | 
			
		||||
dev = [
 | 
			
		||||
  "mypy == 1.14.1",
 | 
			
		||||
  "pytest >=8",
 | 
			
		||||
  "mktestdocs>=0.2.5",
 | 
			
		||||
  "mypy>=1.15.0",
 | 
			
		||||
  "pytest>=8",
 | 
			
		||||
  # Whatever version is compatible with pytest.
 | 
			
		||||
  "pytest-benchmark",
 | 
			
		||||
  "torch>=2.5",
 | 
			
		||||
  "types-pyyaml"
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[project.optional-dependencies]
 | 
			
		||||
torch = ["torch"]
 | 
			
		||||
docs = [
 | 
			
		||||
  "hf-doc-builder",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
[project.scripts]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										8
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,8 @@
 | 
			
		||||
[pytest]
 | 
			
		||||
markers =
 | 
			
		||||
    cuda_only: marks tests that should only hosts with CUDA GPUs
 | 
			
		||||
    rocm_only: marks tests that should only run on hosts with ROCm GPUs
 | 
			
		||||
    darwin_only: marks tests that should only run on macOS
 | 
			
		||||
    xpu_only: marks tests that should only run on hosts with Intel XPUs
 | 
			
		||||
    token: enable tests that require a write token
 | 
			
		||||
    is_staging_test: Marks tests that should only run on a staging environment
 | 
			
		||||
@ -1,23 +1,46 @@
 | 
			
		||||
import importlib.metadata
 | 
			
		||||
 | 
			
		||||
__version__ = importlib.metadata.version("kernels")
 | 
			
		||||
 | 
			
		||||
from kernels.layer import (
 | 
			
		||||
    CUDAProperties,
 | 
			
		||||
    Device,
 | 
			
		||||
    LayerRepository,
 | 
			
		||||
    LocalLayerRepository,
 | 
			
		||||
    LockedLayerRepository,
 | 
			
		||||
    Mode,
 | 
			
		||||
    kernelize,
 | 
			
		||||
    register_kernel_mapping,
 | 
			
		||||
    replace_kernel_forward_from_hub,
 | 
			
		||||
    use_kernel_forward_from_hub,
 | 
			
		||||
    use_kernel_mapping,
 | 
			
		||||
)
 | 
			
		||||
from kernels.utils import (
 | 
			
		||||
    get_kernel,
 | 
			
		||||
    get_local_kernel,
 | 
			
		||||
    get_locked_kernel,
 | 
			
		||||
    has_kernel,
 | 
			
		||||
    install_kernel,
 | 
			
		||||
    load_kernel,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
    "get_kernel",
 | 
			
		||||
    "get_locked_kernel",
 | 
			
		||||
    "load_kernel",
 | 
			
		||||
    "install_kernel",
 | 
			
		||||
    "use_kernel_forward_from_hub",
 | 
			
		||||
    "register_kernel_mapping",
 | 
			
		||||
    "LayerRepository",
 | 
			
		||||
    "__version__",
 | 
			
		||||
    "CUDAProperties",
 | 
			
		||||
    "Device",
 | 
			
		||||
    "LayerRepository",
 | 
			
		||||
    "LocalLayerRepository",
 | 
			
		||||
    "LockedLayerRepository",
 | 
			
		||||
    "Mode",
 | 
			
		||||
    "get_kernel",
 | 
			
		||||
    "get_local_kernel",
 | 
			
		||||
    "get_locked_kernel",
 | 
			
		||||
    "has_kernel",
 | 
			
		||||
    "install_kernel",
 | 
			
		||||
    "kernelize",
 | 
			
		||||
    "load_kernel",
 | 
			
		||||
    "register_kernel_mapping",
 | 
			
		||||
    "replace_kernel_forward_from_hub",
 | 
			
		||||
    "use_kernel_forward_from_hub",
 | 
			
		||||
    "use_kernel_mapping",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										200
									
								
								src/kernels/_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										200
									
								
								src/kernels/_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,200 @@
 | 
			
		||||
# AVL-balanced interval trees. We could use the intervaltree
 | 
			
		||||
# packages, but it seems unmaintained and does not have type
 | 
			
		||||
# annotations.
 | 
			
		||||
 | 
			
		||||
from typing import Generic, List, Optional, Tuple, TypeVar
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _Node(Generic[T]):
 | 
			
		||||
    """A node in the interval tree."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, start: int, end: int, data: T):
 | 
			
		||||
        self.start: int = start
 | 
			
		||||
        self.end: int = end
 | 
			
		||||
        self.data: T = data
 | 
			
		||||
        self.max_end: int = end
 | 
			
		||||
        self.left: Optional["_Node[T]"] = None
 | 
			
		||||
        self.right: Optional["_Node[T]"] = None
 | 
			
		||||
        self.height: int = 1
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return f"Node({self.start}, {self.end})"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IntervalTree(Generic[T]):
 | 
			
		||||
    """A data structure to hold and query (unique) intervals."""
 | 
			
		||||
 | 
			
		||||
    root: Optional[_Node[T]]
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.root = None
 | 
			
		||||
 | 
			
		||||
    def insert(self, start: int, end: int, data: T) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Inserts a new interval into the tree.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            start: The starting point of the interval.
 | 
			
		||||
            end: The ending point of the interval.
 | 
			
		||||
            data: The data associated with this interval.
 | 
			
		||||
        """
 | 
			
		||||
        self.root = self._insert(self.root, start, end, data)
 | 
			
		||||
 | 
			
		||||
    def _get_height(self, node: Optional[_Node[T]]) -> int:
 | 
			
		||||
        if not node:
 | 
			
		||||
            return 0
 | 
			
		||||
        return node.height
 | 
			
		||||
 | 
			
		||||
    def _get_balance(self, node: Optional[_Node[T]]) -> int:
 | 
			
		||||
        if not node:
 | 
			
		||||
            return 0
 | 
			
		||||
        return self._get_height(node.left) - self._get_height(node.right)
 | 
			
		||||
 | 
			
		||||
    def _update_node_attributes(self, node: _Node[T]) -> None:
 | 
			
		||||
        node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
 | 
			
		||||
        node.max_end = node.end
 | 
			
		||||
        if node.left:
 | 
			
		||||
            node.max_end = max(node.max_end, node.left.max_end)
 | 
			
		||||
        if node.right:
 | 
			
		||||
            node.max_end = max(node.max_end, node.right.max_end)
 | 
			
		||||
 | 
			
		||||
    def _right_rotate(self, y: _Node[T]) -> _Node[T]:
 | 
			
		||||
        """Performs a right rotation."""
 | 
			
		||||
        x = y.left
 | 
			
		||||
        assert x is not None
 | 
			
		||||
        T2 = x.right
 | 
			
		||||
 | 
			
		||||
        x.right = y
 | 
			
		||||
        y.left = T2
 | 
			
		||||
 | 
			
		||||
        self._update_node_attributes(y)
 | 
			
		||||
        self._update_node_attributes(x)
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def _left_rotate(self, x: _Node[T]) -> _Node[T]:
 | 
			
		||||
        """Performs a left rotation."""
 | 
			
		||||
        y = x.right
 | 
			
		||||
        assert y is not None
 | 
			
		||||
        T2 = y.left
 | 
			
		||||
 | 
			
		||||
        y.left = x
 | 
			
		||||
        x.right = T2
 | 
			
		||||
 | 
			
		||||
        self._update_node_attributes(x)
 | 
			
		||||
        self._update_node_attributes(y)
 | 
			
		||||
 | 
			
		||||
        return y
 | 
			
		||||
 | 
			
		||||
    def _insert(
 | 
			
		||||
        self, node: Optional[_Node[T]], start: int, end: int, data: T
 | 
			
		||||
    ) -> _Node[T]:
 | 
			
		||||
        """Recursive helper to insert a new node and balance the tree."""
 | 
			
		||||
        if not node:
 | 
			
		||||
            return _Node(start, end, data)
 | 
			
		||||
 | 
			
		||||
        # Replace the data if the interval already exists.
 | 
			
		||||
        if start == node.start and end == node.end:
 | 
			
		||||
            node.data = data
 | 
			
		||||
            return node
 | 
			
		||||
 | 
			
		||||
        if start < node.start:
 | 
			
		||||
            node.left = self._insert(node.left, start, end, data)
 | 
			
		||||
        else:
 | 
			
		||||
            node.right = self._insert(node.right, start, end, data)
 | 
			
		||||
 | 
			
		||||
        self._update_node_attributes(node)
 | 
			
		||||
 | 
			
		||||
        balance = self._get_balance(node)
 | 
			
		||||
 | 
			
		||||
        # Left Left Case
 | 
			
		||||
        if balance > 1 and node.left and start < node.left.start:
 | 
			
		||||
            return self._right_rotate(node)
 | 
			
		||||
 | 
			
		||||
        # Right Right Case
 | 
			
		||||
        if balance < -1 and node.right and start >= node.right.start:
 | 
			
		||||
            return self._left_rotate(node)
 | 
			
		||||
 | 
			
		||||
        # Left Right Case
 | 
			
		||||
        if balance > 1 and node.left and start >= node.left.start:
 | 
			
		||||
            node.left = self._left_rotate(node.left)
 | 
			
		||||
            return self._right_rotate(node)
 | 
			
		||||
 | 
			
		||||
        # Right Left Case
 | 
			
		||||
        if balance < -1 and node.right and start < node.right.start:
 | 
			
		||||
            node.right = self._right_rotate(node.right)
 | 
			
		||||
            return self._left_rotate(node)
 | 
			
		||||
 | 
			
		||||
        return node
 | 
			
		||||
 | 
			
		||||
    def search(self, point: int) -> List[T]:
 | 
			
		||||
        """
 | 
			
		||||
        Searches for all intervals that contain the given point.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            point: The point to search for.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A list of data items from all matching intervals.
 | 
			
		||||
        """
 | 
			
		||||
        results: List[T] = []
 | 
			
		||||
        self._search(self.root, point, results)
 | 
			
		||||
        return results
 | 
			
		||||
 | 
			
		||||
    def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
 | 
			
		||||
        """Recursive helper to find all overlapping intervals."""
 | 
			
		||||
        if node is None or point > node.max_end:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if node.left:
 | 
			
		||||
            self._search(node.left, point, results)
 | 
			
		||||
 | 
			
		||||
        if node.start <= point <= node.end:
 | 
			
		||||
            results.append(node.data)
 | 
			
		||||
 | 
			
		||||
        if point >= node.start and node.right:
 | 
			
		||||
            self._search(node.right, point, results)
 | 
			
		||||
 | 
			
		||||
    def find_smallest_interval(self, point: int) -> Optional[T]:
 | 
			
		||||
        """
 | 
			
		||||
        Finds the item with the most specific (smallest) range for a given point.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            point: The capability to look up.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The data of the best-matching item, or None if no match is found.
 | 
			
		||||
        """
 | 
			
		||||
        matches: List[Tuple[int, int, T]] = []
 | 
			
		||||
        self._find_with_intervals(self.root, point, matches)
 | 
			
		||||
 | 
			
		||||
        if not matches:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        # Return the smallest interval, sort by memory location when
 | 
			
		||||
        # there are multiple matches with the same interval size. This
 | 
			
		||||
        # is just to ensure that we can compare against a trivial
 | 
			
		||||
        # implementation in tests.
 | 
			
		||||
        best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
 | 
			
		||||
        return best_match[2]
 | 
			
		||||
 | 
			
		||||
    def _find_with_intervals(
 | 
			
		||||
        self,
 | 
			
		||||
        node: Optional[_Node[T]],
 | 
			
		||||
        point: int,
 | 
			
		||||
        results: List[Tuple[int, int, T]],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """A modified search that collects interval ranges along with data."""
 | 
			
		||||
        if node is None or point > node.max_end:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if node.left:
 | 
			
		||||
            self._find_with_intervals(node.left, point, results)
 | 
			
		||||
 | 
			
		||||
        if node.start <= point <= node.end:
 | 
			
		||||
            results.append((node.start, node.end, node.data))
 | 
			
		||||
 | 
			
		||||
        if point >= node.start and node.right:
 | 
			
		||||
            self._find_with_intervals(node.right, point, results)
 | 
			
		||||
							
								
								
									
										751
									
								
								src/kernels/_vendored/convert_rst_to_mdx.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										751
									
								
								src/kernels/_vendored/convert_rst_to_mdx.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,751 @@
 | 
			
		||||
# coding=utf-8
 | 
			
		||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
 | 
			
		||||
# Re pattern to catch things inside ` ` in :obj:`thing`.
 | 
			
		||||
_re_obj = re.compile(r":obj:`([^`]+)`")
 | 
			
		||||
# Re pattern to catch things inside ` ` in :math:`thing`.
 | 
			
		||||
_re_math = re.compile(r":math:`([^`]+)`")
 | 
			
		||||
# Re pattern to catch things between single backquotes.
 | 
			
		||||
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
 | 
			
		||||
# Re pattern to catch things between double backquotes.
 | 
			
		||||
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
 | 
			
		||||
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
 | 
			
		||||
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_rst_formatting(text):
 | 
			
		||||
    """
 | 
			
		||||
    Convert rst syntax for formatting to markdown in a given text.
 | 
			
		||||
    """
 | 
			
		||||
    # Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
 | 
			
		||||
    # (to not be caught by the italic conversion).
 | 
			
		||||
    text = _re_func_class.sub(r"[``\1``]", text)
 | 
			
		||||
    # Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
 | 
			
		||||
    # (to not be caught by the italic conversion).
 | 
			
		||||
    text = _re_obj.sub(r"``\1``", text)
 | 
			
		||||
    # Remove :math: markers.
 | 
			
		||||
    text = _re_math.sub(r"\\\\(\1\\\\)", text)
 | 
			
		||||
    # Convert content in single backquotes to italic.
 | 
			
		||||
    text = _re_single_backquotes.sub(r"\1*\2*\3", text)
 | 
			
		||||
    # Convert content in double backquotes to single backquotes.
 | 
			
		||||
    text = _re_double_backquotes.sub(r"\1`\2`\3", text)
 | 
			
		||||
    # Remove remaining ::
 | 
			
		||||
    text = re.sub(r"::\n", "", text)
 | 
			
		||||
 | 
			
		||||
    # Remove new lines inside blocks in backsticks as they will be kept.
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    in_code = False
 | 
			
		||||
    text = None
 | 
			
		||||
    for line in lines:
 | 
			
		||||
        if in_code:
 | 
			
		||||
            splits = line.split("`")
 | 
			
		||||
            in_code = len(splits) > 1 and len(splits) % 2 == 1
 | 
			
		||||
            if len(splits) == 1:
 | 
			
		||||
                # Some forgotten lone backstick
 | 
			
		||||
                text += "\n" + line
 | 
			
		||||
            else:
 | 
			
		||||
                text += " " + line.lstrip()
 | 
			
		||||
        else:
 | 
			
		||||
            if text is not None:
 | 
			
		||||
                text += "\n" + line
 | 
			
		||||
            else:
 | 
			
		||||
                text = line
 | 
			
		||||
            splits = line.split("`")
 | 
			
		||||
            in_code = len(splits) % 2 == 0
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Re pattern to catch description and url in links of the form `description <url>`_.
 | 
			
		||||
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
 | 
			
		||||
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
 | 
			
		||||
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
 | 
			
		||||
# Re pattern to catch reference in links of the form :doc:`reference`.
 | 
			
		||||
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
 | 
			
		||||
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
 | 
			
		||||
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
 | 
			
		||||
# Re pattern to catch reference in links of the form :ref:`reference`.
 | 
			
		||||
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
 | 
			
		||||
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
 | 
			
		||||
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_rst_links(text, page_info):
 | 
			
		||||
    """
 | 
			
		||||
    Convert the rst links in text to markdown.
 | 
			
		||||
    """
 | 
			
		||||
    if "package_name" not in page_info:
 | 
			
		||||
        raise ValueError("`page_info` must contain at least the package_name.")
 | 
			
		||||
    package_name = page_info["package_name"]
 | 
			
		||||
    version = page_info.get("version", "main")
 | 
			
		||||
    language = page_info.get("language", "en")
 | 
			
		||||
    no_prefix = page_info.get("no_prefix", False)
 | 
			
		||||
 | 
			
		||||
    prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
 | 
			
		||||
    # Links of the form :doc:`page`
 | 
			
		||||
    text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
 | 
			
		||||
    # Links of the form :doc:`text <page>`
 | 
			
		||||
    text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
 | 
			
		||||
 | 
			
		||||
    if "page" in page_info and not no_prefix:
 | 
			
		||||
        page = str(page_info["page"])
 | 
			
		||||
        if page.endswith(".html"):
 | 
			
		||||
            page = page[:-5]
 | 
			
		||||
        prefix = f"{prefix}{page}"
 | 
			
		||||
    else:
 | 
			
		||||
        prefix = ""
 | 
			
		||||
    # Refs of the form :ref:`page`
 | 
			
		||||
    text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
 | 
			
		||||
    # Refs of the form :ref:`text <page>`
 | 
			
		||||
    text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
 | 
			
		||||
 | 
			
		||||
    # Links with a prefix
 | 
			
		||||
    # TODO: when it exists, use the API to deal with prefix links properly.
 | 
			
		||||
    prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
 | 
			
		||||
    text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
 | 
			
		||||
    # Other links
 | 
			
		||||
    text = _re_links.sub(r"[\1](\2)", text)
 | 
			
		||||
    # Relative links or Transformers links need to remove the .html
 | 
			
		||||
    if (
 | 
			
		||||
        "(https://https://huggingface.co/" in text
 | 
			
		||||
        or re.search(r"\(\.+/", text) is not None
 | 
			
		||||
    ):
 | 
			
		||||
        text = text.replace(".html", "")
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Re pattern that catches examples blocks of the form `Example::`.
 | 
			
		||||
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
 | 
			
		||||
# Re pattern that catches rst blocks of the form `.. block_name::`.
 | 
			
		||||
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
 | 
			
		||||
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
 | 
			
		||||
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_empty_line(line):
 | 
			
		||||
    return len(line) == 0 or line.isspace()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def find_indent(line):
 | 
			
		||||
    """
 | 
			
		||||
    Returns the number of spaces that start a line indent.
 | 
			
		||||
    """
 | 
			
		||||
    search = re.search(r"^(\s*)(?:\S|$)", line)
 | 
			
		||||
    if search is None:
 | 
			
		||||
        return 0
 | 
			
		||||
    return len(search.groups()[0])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_special_chars(text):
 | 
			
		||||
    """
 | 
			
		||||
    Converts { and < that have special meanings in MDX.
 | 
			
		||||
    """
 | 
			
		||||
    text = text.replace("{", "&lcub;")
 | 
			
		||||
    # We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
 | 
			
		||||
    text = re.sub(
 | 
			
		||||
        r"<(img|br|hr|Youtube)", r"LTHTML\1", text
 | 
			
		||||
    )  # html void elements with no closing counterpart
 | 
			
		||||
    _re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
 | 
			
		||||
    while _re_lt_html.search(text):
 | 
			
		||||
        text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
 | 
			
		||||
    text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text)
 | 
			
		||||
    text = text.replace("LTHTML", "<")
 | 
			
		||||
    return text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_options(block_content):
 | 
			
		||||
    """
 | 
			
		||||
    Parses the option in some rst block content.
 | 
			
		||||
    """
 | 
			
		||||
    block_lines = block_content.split("\n")
 | 
			
		||||
    block_indent = find_indent(block_lines[0])
 | 
			
		||||
    current_option = None
 | 
			
		||||
    result = {}
 | 
			
		||||
    for line in block_lines:
 | 
			
		||||
        if _re_rst_option.search(line) is not None:
 | 
			
		||||
            current_option, value = _re_rst_option.search(line).groups()
 | 
			
		||||
            result[current_option] = value.lstrip()
 | 
			
		||||
        elif find_indent(line) > block_indent:
 | 
			
		||||
            result[current_option] += " " + line.lstrip()
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_min_indent(text, min_indent):
 | 
			
		||||
    """
 | 
			
		||||
    Make sure all lines in a text are have a minimum indentation.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        text (`str`): The text to treat.
 | 
			
		||||
        min_indent (`int`): The minimal indentation.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `str`: The processed text.
 | 
			
		||||
    """
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    idx = 0
 | 
			
		||||
    while idx < len(lines):
 | 
			
		||||
        if is_empty_line(lines[idx]):
 | 
			
		||||
            idx += 1
 | 
			
		||||
            continue
 | 
			
		||||
        indent = find_indent(lines[idx])
 | 
			
		||||
        if indent < min_indent:
 | 
			
		||||
            while idx < len(lines) and (
 | 
			
		||||
                find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
 | 
			
		||||
            ):
 | 
			
		||||
                if not is_empty_line(lines[idx]):
 | 
			
		||||
                    lines[idx] = " " * (min_indent - indent) + lines[idx]
 | 
			
		||||
                idx += 1
 | 
			
		||||
        else:
 | 
			
		||||
            idx += 1
 | 
			
		||||
 | 
			
		||||
    return "\n".join(lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_rst_blocks(text, page_info):
 | 
			
		||||
    """
 | 
			
		||||
    Converts rst special blocks (examples, notes) into MDX.
 | 
			
		||||
    """
 | 
			
		||||
    if "package_name" not in page_info:
 | 
			
		||||
        raise ValueError("`page_info` must contain at least the package_name.")
 | 
			
		||||
    package_name = page_info["package_name"]
 | 
			
		||||
    version = page_info.get("version", "main")
 | 
			
		||||
    language = page_info.get("language", "en")
 | 
			
		||||
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    idx = 0
 | 
			
		||||
    new_lines = []
 | 
			
		||||
    while idx < len(lines):
 | 
			
		||||
        block_type = None
 | 
			
		||||
        block_info = None
 | 
			
		||||
        if _re_block.search(lines[idx]) is not None:
 | 
			
		||||
            block_type = _re_block.search(lines[idx]).groups()[0]
 | 
			
		||||
            if _re_block_info.search(lines[idx]) is not None:
 | 
			
		||||
                block_info = _re_block_info.search(lines[idx]).groups()[0]
 | 
			
		||||
        elif _re_example.search(lines[idx]) is not None:
 | 
			
		||||
            block_type = "code-block-example"
 | 
			
		||||
            block_info = "python"
 | 
			
		||||
            example_name = _re_example.search(lines[idx]).groups()[0]
 | 
			
		||||
            new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
 | 
			
		||||
        elif lines[idx].strip() == "..":
 | 
			
		||||
            block_type = "comment"
 | 
			
		||||
        elif lines[idx].strip() == "::":
 | 
			
		||||
            block_type = "code-block"
 | 
			
		||||
 | 
			
		||||
        if block_type is not None:
 | 
			
		||||
            block_indent = find_indent(lines[idx])
 | 
			
		||||
            # Find the next nonempty line
 | 
			
		||||
            idx += 1
 | 
			
		||||
            while idx < len(lines) and is_empty_line(lines[idx]):
 | 
			
		||||
                idx += 1
 | 
			
		||||
            # Grab the indent of the return line, this block will stop when we unindent under it (or has already)
 | 
			
		||||
            example_indent = (
 | 
			
		||||
                find_indent(lines[idx]) if idx < len(lines) else block_indent
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if example_indent == block_indent:
 | 
			
		||||
                block_content = ""
 | 
			
		||||
            else:
 | 
			
		||||
                block_lines = []
 | 
			
		||||
                while idx < len(lines) and (
 | 
			
		||||
                    is_empty_line(lines[idx])
 | 
			
		||||
                    or find_indent(lines[idx]) >= example_indent
 | 
			
		||||
                ):
 | 
			
		||||
                    block_lines.append(lines[idx][example_indent:])
 | 
			
		||||
                    idx += 1
 | 
			
		||||
                block_content = "\n".join(block_lines)
 | 
			
		||||
 | 
			
		||||
            if block_type in ["code", "code-block"]:
 | 
			
		||||
                prefix = "```" if block_info is None else f"```{block_info}"
 | 
			
		||||
                new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
 | 
			
		||||
            elif block_type == "code-block-example":
 | 
			
		||||
                prefix = f"<example>```{block_info}"
 | 
			
		||||
                new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
 | 
			
		||||
            elif block_type == "note":
 | 
			
		||||
                new_lines.append(
 | 
			
		||||
                    apply_min_indent(
 | 
			
		||||
                        f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            elif block_type == "warning":
 | 
			
		||||
                new_lines.append(
 | 
			
		||||
                    apply_min_indent(
 | 
			
		||||
                        "<Tip warning={true}>\n\n"
 | 
			
		||||
                        + f"{block_content.strip()}\n\n</Tip>\n",
 | 
			
		||||
                        block_indent,
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            elif block_type == "raw":
 | 
			
		||||
                new_lines.append(block_content.strip() + "\n")
 | 
			
		||||
            elif block_type == "math":
 | 
			
		||||
                new_lines.append(f"$${block_content.strip()}$$\n")
 | 
			
		||||
            elif block_type == "comment":
 | 
			
		||||
                new_lines.append(f"<!--{block_content.strip()}\n-->\n")
 | 
			
		||||
            elif block_type == "autofunction":
 | 
			
		||||
                if block_info is not None:
 | 
			
		||||
                    new_lines.append(f"[[autodoc]] {block_info}\n")
 | 
			
		||||
            elif block_type == "autoclass":
 | 
			
		||||
                if block_info is not None:
 | 
			
		||||
                    block = f"[[autodoc]] {block_info}\n"
 | 
			
		||||
                    options = parse_options(block_content)
 | 
			
		||||
                    if "special-members" in options:
 | 
			
		||||
                        special_members = options["special-members"].split(", ")
 | 
			
		||||
                        for special_member in special_members:
 | 
			
		||||
                            block += f"    - {special_member}\n"
 | 
			
		||||
                    if "members" in options:
 | 
			
		||||
                        members = options["members"]
 | 
			
		||||
                        if len(members) == 0:
 | 
			
		||||
                            block += "    - all\n"
 | 
			
		||||
                        else:
 | 
			
		||||
                            for member in members.split(", "):
 | 
			
		||||
                                block += f"    - {member}\n"
 | 
			
		||||
                    new_lines.append(block)
 | 
			
		||||
            elif block_type == "image":
 | 
			
		||||
                options = parse_options(block_content)
 | 
			
		||||
                target = options.pop("target", None)
 | 
			
		||||
                if block_info is not None:
 | 
			
		||||
                    options["src"] = block_info
 | 
			
		||||
                else:
 | 
			
		||||
                    if target is None:
 | 
			
		||||
                        raise ValueError("Image source not defined.")
 | 
			
		||||
                    options["src"] = target
 | 
			
		||||
                # Adapt path
 | 
			
		||||
                options["src"] = options["src"].replace(
 | 
			
		||||
                    "/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
 | 
			
		||||
                )
 | 
			
		||||
                html_code = " ".join(
 | 
			
		||||
                    [f'{key}="{value}"' for key, value in options.items()]
 | 
			
		||||
                )
 | 
			
		||||
                new_lines.append(f"<img {html_code}/>\n")
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                new_lines.append(
 | 
			
		||||
                    f"{block_type},{block_info}\n{block_content.rstrip()}\n"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            new_lines.append(lines[idx])
 | 
			
		||||
            idx += 1
 | 
			
		||||
 | 
			
		||||
    return "\n".join(new_lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Re pattern that catches rst args blocks of the form `Parameters:`.
 | 
			
		||||
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
 | 
			
		||||
# Re pattern that catches return blocks of the form `Return:`.
 | 
			
		||||
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_return_line(line):
 | 
			
		||||
    """
 | 
			
		||||
    Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
 | 
			
		||||
    """
 | 
			
		||||
    splits_on_colon = line.split(":")
 | 
			
		||||
    idx = 1
 | 
			
		||||
    while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
 | 
			
		||||
        idx += 2
 | 
			
		||||
    if idx >= len(splits_on_colon):
 | 
			
		||||
        if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
 | 
			
		||||
            return line, ""
 | 
			
		||||
        return None, line
 | 
			
		||||
    return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_raise_line(line):
 | 
			
		||||
    """
 | 
			
		||||
    Split the raise line with format `SomeError some doc`.
 | 
			
		||||
    """
 | 
			
		||||
    splits_on_colon = line.strip().split(" ")
 | 
			
		||||
    error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:])
 | 
			
		||||
    if error_type and error_type[-1] == ":":
 | 
			
		||||
        error_type = error_type[:-1]
 | 
			
		||||
    return error_type, doc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_arg_line(line):
 | 
			
		||||
    """
 | 
			
		||||
    Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:.
 | 
			
		||||
    """
 | 
			
		||||
    splits_on_colon = line.split(":")
 | 
			
		||||
    idx = 1
 | 
			
		||||
    while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]:
 | 
			
		||||
        idx += 2
 | 
			
		||||
    if idx >= len(splits_on_colon):
 | 
			
		||||
        return line, ""
 | 
			
		||||
    return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidRstDocstringError(ValueError):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_re_parameters = re.compile(
 | 
			
		||||
    r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
 | 
			
		||||
)
 | 
			
		||||
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def parse_rst_docstring(docstring):
 | 
			
		||||
    """
 | 
			
		||||
    Parses a docstring written in rst, in particular the list of arguments and the return type.
 | 
			
		||||
    """
 | 
			
		||||
    lines = docstring.split("\n")
 | 
			
		||||
    idx = 0
 | 
			
		||||
    while idx < len(lines):
 | 
			
		||||
        # Parameters section
 | 
			
		||||
        if _re_args.search(lines[idx]) is not None:
 | 
			
		||||
            # Title of the section.
 | 
			
		||||
            lines[idx] = "<parameters>\n"
 | 
			
		||||
            # Find the next nonempty line
 | 
			
		||||
            idx += 1
 | 
			
		||||
            while is_empty_line(lines[idx]):
 | 
			
		||||
                idx += 1
 | 
			
		||||
            # Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
 | 
			
		||||
            # Returns or Raises block.
 | 
			
		||||
            param_indent = find_indent(lines[idx])
 | 
			
		||||
            while (
 | 
			
		||||
                idx < len(lines)
 | 
			
		||||
                and find_indent(lines[idx]) == param_indent
 | 
			
		||||
                and _re_returns.search(lines[idx]) is None
 | 
			
		||||
            ):
 | 
			
		||||
                intro, doc = split_arg_line(lines[idx])
 | 
			
		||||
                # Line starting with a > after indent indicate a "section title" in the parameters.
 | 
			
		||||
                if intro.lstrip().startswith(">"):
 | 
			
		||||
                    lines[idx] = intro.lstrip()
 | 
			
		||||
                else:
 | 
			
		||||
                    lines[idx] = (
 | 
			
		||||
                        re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
 | 
			
		||||
                    )
 | 
			
		||||
                idx += 1
 | 
			
		||||
                while idx < len(lines) and (
 | 
			
		||||
                    is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
 | 
			
		||||
                ):
 | 
			
		||||
                    idx += 1
 | 
			
		||||
            lines.insert(idx, "</parameters>\n")
 | 
			
		||||
            idx += 1
 | 
			
		||||
 | 
			
		||||
        # Returns section
 | 
			
		||||
        elif _re_returns.search(lines[idx]) is not None:
 | 
			
		||||
            # tag is either `return` or `yield`
 | 
			
		||||
            tag = _re_returns.match(lines[idx]).group(1).lower()
 | 
			
		||||
            # Title of the section.
 | 
			
		||||
            lines[idx] = f"<{tag}s>\n"
 | 
			
		||||
            # Find the next nonempty line
 | 
			
		||||
            idx += 1
 | 
			
		||||
            while is_empty_line(lines[idx]):
 | 
			
		||||
                idx += 1
 | 
			
		||||
 | 
			
		||||
            # Grab the indent of the return line, this block will stop when we unindent under it.
 | 
			
		||||
            return_indent = find_indent(lines[idx])
 | 
			
		||||
            raised_errors = []
 | 
			
		||||
            # The line may contain the return type.
 | 
			
		||||
            if tag in ["return", "yield"]:
 | 
			
		||||
                return_type, return_description = split_return_line(lines[idx])
 | 
			
		||||
                lines[idx] = return_description
 | 
			
		||||
                idx += 1
 | 
			
		||||
                while idx < len(lines) and (
 | 
			
		||||
                    is_empty_line(lines[idx])
 | 
			
		||||
                    or find_indent(lines[idx]) >= return_indent
 | 
			
		||||
                ):
 | 
			
		||||
                    idx += 1
 | 
			
		||||
            else:
 | 
			
		||||
                while idx < len(lines) and find_indent(lines[idx]) == return_indent:
 | 
			
		||||
                    return_type, return_description = split_raise_line(lines[idx])
 | 
			
		||||
                    raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
 | 
			
		||||
                    lines[idx] = "- " + raised_error + " -- " + return_description
 | 
			
		||||
                    md_link = _re_md_link.match(raised_error)
 | 
			
		||||
                    if md_link:
 | 
			
		||||
                        raised_error = md_link[1]
 | 
			
		||||
                        raised_error = re.sub(
 | 
			
		||||
                            r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
 | 
			
		||||
                        )
 | 
			
		||||
                    if raised_error not in raised_errors:
 | 
			
		||||
                        raised_errors.append(raised_error)
 | 
			
		||||
                    idx += 1
 | 
			
		||||
                    while idx < len(lines) and (
 | 
			
		||||
                        is_empty_line(lines[idx])
 | 
			
		||||
                        or find_indent(lines[idx]) > return_indent
 | 
			
		||||
                    ):
 | 
			
		||||
                        idx += 1
 | 
			
		||||
 | 
			
		||||
            lines.insert(idx, f"</{tag}s>\n")
 | 
			
		||||
            idx += 1
 | 
			
		||||
 | 
			
		||||
            # Return block finished, we insert the return type if one was specified
 | 
			
		||||
            if tag in ["return", "yield"] and return_type is not None:
 | 
			
		||||
                lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
 | 
			
		||||
            elif len(raised_errors) > 0:
 | 
			
		||||
                # raised errors
 | 
			
		||||
                lines[
 | 
			
		||||
                    idx - 1
 | 
			
		||||
                ] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            idx += 1
 | 
			
		||||
 | 
			
		||||
    result = "\n".join(lines)
 | 
			
		||||
 | 
			
		||||
    # combine multiple <parameters> blocks into one block
 | 
			
		||||
    if result.count("<parameters>") > 1:
 | 
			
		||||
        parameters_blocks = _re_parameters.findall(result)
 | 
			
		||||
        parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
 | 
			
		||||
        parameters_str = "\n".join(parameters_blocks)
 | 
			
		||||
        result = _re_parameters.sub("", result)
 | 
			
		||||
        result += f"\n<parameters>{parameters_str}</parameters>\n"
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
 | 
			
		||||
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def remove_indent(text):
 | 
			
		||||
    """
 | 
			
		||||
    Remove indents in text, except the one linked to lists (or sublists).
 | 
			
		||||
    """
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    # List of indents to remember for nested lists
 | 
			
		||||
    current_indents = []
 | 
			
		||||
    # List of new indents to remember for nested lists
 | 
			
		||||
    new_indents = []
 | 
			
		||||
    is_inside_code = False
 | 
			
		||||
    code_indent = 0
 | 
			
		||||
    for idx, line in enumerate(lines):
 | 
			
		||||
        # Line is an item in a list.
 | 
			
		||||
        if _re_list.search(line) is not None:
 | 
			
		||||
            indent = find_indent(line)
 | 
			
		||||
            # Is it a new list / new level of nestedness?
 | 
			
		||||
            if len(current_indents) == 0 or indent > current_indents[-1]:
 | 
			
		||||
                current_indents.append(indent)
 | 
			
		||||
                new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
 | 
			
		||||
                lines[idx] = " " * new_indent + line[indent:]
 | 
			
		||||
                new_indent += len(_re_list.search(line).groups()[0]) + 1
 | 
			
		||||
                new_indents.append(new_indent)
 | 
			
		||||
            # Otherwise it's an existing level of list (current one, or previous one)
 | 
			
		||||
            else:
 | 
			
		||||
                # Let's find the proper level of indentation
 | 
			
		||||
                level = len(current_indents) - 1
 | 
			
		||||
                while level >= 0 and current_indents[level] != indent:
 | 
			
		||||
                    level -= 1
 | 
			
		||||
                current_indents = current_indents[: level + 1]
 | 
			
		||||
                new_indents = new_indents[:level]
 | 
			
		||||
                new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
 | 
			
		||||
                lines[idx] = " " * new_indent + line[indent:]
 | 
			
		||||
                new_indent += len(_re_list.search(line).groups()[0]) + 1
 | 
			
		||||
                new_indents.append(new_indent)
 | 
			
		||||
 | 
			
		||||
        # Line is an autodoc, we keep the indent for the list just after if there is one.
 | 
			
		||||
        elif _re_autodoc.search(line) is not None:
 | 
			
		||||
            indent = find_indent(line)
 | 
			
		||||
            current_indents = [indent]
 | 
			
		||||
            new_indents = [4]
 | 
			
		||||
            lines[idx] = line.strip()
 | 
			
		||||
 | 
			
		||||
        # Deal with empty lines separately
 | 
			
		||||
        elif is_empty_line(line):
 | 
			
		||||
            lines[idx] = ""
 | 
			
		||||
 | 
			
		||||
        # Code blocks
 | 
			
		||||
        elif line.lstrip().startswith("```"):
 | 
			
		||||
            is_inside_code = not is_inside_code
 | 
			
		||||
            if is_inside_code:
 | 
			
		||||
                code_indent = find_indent(line)
 | 
			
		||||
            lines[idx] = line[code_indent:]
 | 
			
		||||
        elif is_inside_code:
 | 
			
		||||
            lines[idx] = line[code_indent:]
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            indent = find_indent(line)
 | 
			
		||||
            if len(current_indents) > 0 and indent > current_indents[-1]:
 | 
			
		||||
                lines[idx] = " " * new_indents[-1] + line[indent:]
 | 
			
		||||
            elif len(current_indents) > 0:
 | 
			
		||||
                # Let's find the proper level of indentation
 | 
			
		||||
                level = len(current_indents) - 1
 | 
			
		||||
                while level >= 0 and current_indents[level] > indent:
 | 
			
		||||
                    level -= 1
 | 
			
		||||
                current_indents = current_indents[: level + 1]
 | 
			
		||||
                if level >= 0:
 | 
			
		||||
                    if current_indents[level] < indent:
 | 
			
		||||
                        new_indents = new_indents[: level + 1]
 | 
			
		||||
                    else:
 | 
			
		||||
                        new_indents = new_indents[:level]
 | 
			
		||||
                    new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
 | 
			
		||||
                    lines[idx] = " " * new_indent + line[indent:]
 | 
			
		||||
                    new_indents.append(new_indent)
 | 
			
		||||
                else:
 | 
			
		||||
                    new_indents = []
 | 
			
		||||
                    lines[idx] = line[indent:]
 | 
			
		||||
            else:
 | 
			
		||||
                lines[idx] = line[indent:]
 | 
			
		||||
 | 
			
		||||
    return "\n".join(lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def base_rst_to_mdx(text, page_info, unindent=True):
 | 
			
		||||
    """
 | 
			
		||||
    Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs.
 | 
			
		||||
    """
 | 
			
		||||
    text = convert_rst_links(text, page_info)
 | 
			
		||||
    text = convert_special_chars(text)
 | 
			
		||||
    text = convert_rst_blocks(text, page_info)
 | 
			
		||||
    # Convert * in lists to - to avoid the formatting conversion treat them as bold.
 | 
			
		||||
    text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE)
 | 
			
		||||
    text = convert_rst_formatting(text)
 | 
			
		||||
    return remove_indent(text) if unindent else text
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_rst_docstring_to_mdx(docstring, page_info):
 | 
			
		||||
    """
 | 
			
		||||
    Convert a docstring written in rst to mdx.
 | 
			
		||||
    """
 | 
			
		||||
    text = parse_rst_docstring(docstring)
 | 
			
		||||
    return base_rst_to_mdx(text, page_info)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def process_titles(lines):
 | 
			
		||||
    """Converts rst titles to markdown titles."""
 | 
			
		||||
    title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
 | 
			
		||||
    title_levels = {}
 | 
			
		||||
    new_lines = []
 | 
			
		||||
    for line in lines:
 | 
			
		||||
        if (
 | 
			
		||||
            len(new_lines) > 0
 | 
			
		||||
            and len(line) >= len(new_lines[-1])
 | 
			
		||||
            and len(set(line)) == 1
 | 
			
		||||
            and line[0] in title_chars
 | 
			
		||||
            and line != "::"
 | 
			
		||||
        ):
 | 
			
		||||
            char = line[0]
 | 
			
		||||
            level = title_levels.get(char, len(title_levels) + 1)
 | 
			
		||||
            if level not in title_levels:
 | 
			
		||||
                title_levels[char] = level
 | 
			
		||||
            new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
 | 
			
		||||
        else:
 | 
			
		||||
            new_lines.append(line)
 | 
			
		||||
    return new_lines
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Matches lines with a pattern of a table new line in rst.
 | 
			
		||||
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
 | 
			
		||||
# Matches lines with a pattern of a table new line in rst, with a first column empty.
 | 
			
		||||
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
 | 
			
		||||
# Matches lines with a pattern of a first table line in rst.
 | 
			
		||||
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
 | 
			
		||||
# Re pattern that catches anchors of the type .. reference:
 | 
			
		||||
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_pt_tf_code_blocks(text):
 | 
			
		||||
    """
 | 
			
		||||
    Split PyTorch and TensorFlow specific block codes.
 | 
			
		||||
    """
 | 
			
		||||
    lines = text.split("\n")
 | 
			
		||||
    new_lines = []
 | 
			
		||||
    idx = 0
 | 
			
		||||
    while idx < len(lines):
 | 
			
		||||
        if lines[idx].startswith("```"):
 | 
			
		||||
            code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []}
 | 
			
		||||
            is_pytorch = False
 | 
			
		||||
            is_tensorflow = False
 | 
			
		||||
            idx += 1
 | 
			
		||||
            while idx < len(lines) and lines[idx].strip() != "```":
 | 
			
		||||
                if "## PYTORCH CODE" in lines[idx]:
 | 
			
		||||
                    is_pytorch = True
 | 
			
		||||
                    is_tensorflow = False
 | 
			
		||||
                elif "## TENSORFLOW CODE" in lines[idx]:
 | 
			
		||||
                    is_tensorflow = True
 | 
			
		||||
                    is_pytorch = False
 | 
			
		||||
                elif is_pytorch:
 | 
			
		||||
                    code_lines["pytorch"].append(lines[idx])
 | 
			
		||||
                elif is_tensorflow:
 | 
			
		||||
                    code_lines["tensorflow"].append(lines[idx])
 | 
			
		||||
                else:
 | 
			
		||||
                    code_lines["common"].append(lines[idx])
 | 
			
		||||
                idx += 1
 | 
			
		||||
            if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0:
 | 
			
		||||
                block_lines = ["<frameworkcontent>", "<pt>"]
 | 
			
		||||
                block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"])
 | 
			
		||||
                block_lines.extend(["```", "</pt>", "<tf>"])
 | 
			
		||||
                block_lines.extend(
 | 
			
		||||
                    code_lines["common"].copy() + code_lines["tensorflow"]
 | 
			
		||||
                )
 | 
			
		||||
                block_lines.extend(["```", "</tf>", "</frameworkcontent>"])
 | 
			
		||||
                new_lines.extend(block_lines)
 | 
			
		||||
            else:
 | 
			
		||||
                block_lines = code_lines["common"] + ["```"]
 | 
			
		||||
                new_lines.extend(block_lines)
 | 
			
		||||
            idx += 1
 | 
			
		||||
        else:
 | 
			
		||||
            new_lines.append(lines[idx])
 | 
			
		||||
            idx += 1
 | 
			
		||||
    return "\n".join(new_lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
 | 
			
		||||
    """
 | 
			
		||||
    Convert a document written in rst to mdx.
 | 
			
		||||
    """
 | 
			
		||||
    lines = rst_text.split("\n")
 | 
			
		||||
    lines = process_titles(lines)
 | 
			
		||||
    if add_imports:
 | 
			
		||||
        new_lines = [
 | 
			
		||||
            '<script lang="ts">',
 | 
			
		||||
            '	import Tip from "$lib/Tip.svelte";',
 | 
			
		||||
            '	import Youtube from "$lib/Youtube.svelte";',
 | 
			
		||||
            '	import Docstring from "$lib/Docstring.svelte";',
 | 
			
		||||
            '	import CodeBlock from "$lib/CodeBlock.svelte";',
 | 
			
		||||
            '	import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
 | 
			
		||||
            '	import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
 | 
			
		||||
            '	import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
 | 
			
		||||
            '	import IconCopyLink from "$lib/IconCopyLink.svelte";',
 | 
			
		||||
            '	import FrameworkContent from "$lib/FrameworkContent.svelte";',
 | 
			
		||||
            '	import Markdown from "$lib/Markdown.svelte";',
 | 
			
		||||
            '	import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
 | 
			
		||||
            '	import Added from "$lib/Added.svelte";',
 | 
			
		||||
            '	import Changed from "$lib/Changed.svelte";',
 | 
			
		||||
            '	import Deprecated from "$lib/Deprecated.svelte";',
 | 
			
		||||
            '	import PipelineIcon from "$lib/PipelineIcon.svelte";',
 | 
			
		||||
            '	import PipelineTag from "$lib/PipelineTag.svelte";',
 | 
			
		||||
            "	",
 | 
			
		||||
            '	export let fw: "pt" | "tf"',
 | 
			
		||||
            "</script>",
 | 
			
		||||
            "<svelte:head>",
 | 
			
		||||
            '<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
 | 
			
		||||
            "</svelte:head>",
 | 
			
		||||
            "",
 | 
			
		||||
        ]
 | 
			
		||||
    else:
 | 
			
		||||
        new_lines = []
 | 
			
		||||
    for line in lines:
 | 
			
		||||
        if _re_ignore_line_table.search(line) is not None:
 | 
			
		||||
            continue
 | 
			
		||||
        elif _re_ignore_line_table1.search(line) is not None:
 | 
			
		||||
            continue
 | 
			
		||||
        elif _re_sep_line_table.search(line) is not None:
 | 
			
		||||
            line = line.replace("=", "-").replace("+", "|")
 | 
			
		||||
        elif _re_anchor_section.search(line) is not None:
 | 
			
		||||
            anchor_name = _re_anchor_section.search(line).groups()[0]
 | 
			
		||||
            line = f"<a id='{anchor_name}'></a>"
 | 
			
		||||
        new_lines.append(line)
 | 
			
		||||
    text = "\n".join(new_lines)
 | 
			
		||||
 | 
			
		||||
    return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))
 | 
			
		||||
							
								
								
									
										52
									
								
								src/kernels/_versions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								src/kernels/_versions.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,52 @@
 | 
			
		||||
from typing import Dict, Optional
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import HfApi
 | 
			
		||||
from huggingface_hub.hf_api import GitRefInfo
 | 
			
		||||
from packaging.specifiers import SpecifierSet
 | 
			
		||||
from packaging.version import InvalidVersion, Version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
 | 
			
		||||
    """Get kernel versions that are available in the repository."""
 | 
			
		||||
    versions = {}
 | 
			
		||||
    for tag in HfApi().list_repo_refs(repo_id).tags:
 | 
			
		||||
        if not tag.name.startswith("v"):
 | 
			
		||||
            continue
 | 
			
		||||
        try:
 | 
			
		||||
            versions[Version(tag.name[1:])] = tag
 | 
			
		||||
        except InvalidVersion:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
    return versions
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
 | 
			
		||||
    """
 | 
			
		||||
    Get the locks for a kernel with the given version spec.
 | 
			
		||||
 | 
			
		||||
    The version specifier can be any valid Python version specifier:
 | 
			
		||||
    https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
 | 
			
		||||
    """
 | 
			
		||||
    versions = _get_available_versions(repo_id)
 | 
			
		||||
    requirement = SpecifierSet(version_spec)
 | 
			
		||||
    accepted_versions = sorted(requirement.filter(versions.keys()))
 | 
			
		||||
 | 
			
		||||
    if len(accepted_versions) == 0:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"No version of `{repo_id}` satisfies requirement: {version_spec}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return versions[accepted_versions[-1]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def select_revision_or_version(
 | 
			
		||||
    repo_id: str, revision: Optional[str], version: Optional[str]
 | 
			
		||||
) -> str:
 | 
			
		||||
    if revision is not None and version is not None:
 | 
			
		||||
        raise ValueError("Either a revision or a version must be specified, not both.")
 | 
			
		||||
    elif revision is None and version is None:
 | 
			
		||||
        revision = "main"
 | 
			
		||||
    elif version is not None:
 | 
			
		||||
        revision = resolve_version_spec_as_ref(repo_id, version).target_commit
 | 
			
		||||
    assert revision is not None
 | 
			
		||||
    return revision
 | 
			
		||||
@ -4,10 +4,15 @@ import json
 | 
			
		||||
import sys
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import create_repo, upload_folder
 | 
			
		||||
 | 
			
		||||
from kernels.compat import tomllib
 | 
			
		||||
from kernels.lockfile import KernelLock, get_kernel_locks
 | 
			
		||||
from kernels.utils import install_kernel, install_kernel_all_variants
 | 
			
		||||
 | 
			
		||||
from .doc import generate_readme_for_kernel
 | 
			
		||||
from .wheel import build_variant_to_wheel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    parser = argparse.ArgumentParser(
 | 
			
		||||
@ -28,6 +33,24 @@ def main():
 | 
			
		||||
    )
 | 
			
		||||
    download_parser.set_defaults(func=download_kernels)
 | 
			
		||||
 | 
			
		||||
    upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "kernel_dir",
 | 
			
		||||
        type=Path,
 | 
			
		||||
        help="Directory of the kernel build",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "--repo_id",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="Repository ID to use to upload to the Hugging Face Hub",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.add_argument(
 | 
			
		||||
        "--private",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="If the repository should be private.",
 | 
			
		||||
    )
 | 
			
		||||
    upload_parser.set_defaults(func=upload_kernels)
 | 
			
		||||
 | 
			
		||||
    lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
 | 
			
		||||
    lock_parser.add_argument(
 | 
			
		||||
        "project_dir",
 | 
			
		||||
@ -36,6 +59,47 @@ def main():
 | 
			
		||||
    )
 | 
			
		||||
    lock_parser.set_defaults(func=lock_kernels)
 | 
			
		||||
 | 
			
		||||
    to_wheel_parser = subparsers.add_parser(
 | 
			
		||||
        "to-wheel", help="Convert a kernel to a wheel file"
 | 
			
		||||
    )
 | 
			
		||||
    to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
 | 
			
		||||
    to_wheel_parser.add_argument("version", type=str, help="The kernel version")
 | 
			
		||||
    to_wheel_parser.add_argument(
 | 
			
		||||
        "--python-version",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="3.9",
 | 
			
		||||
        help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
 | 
			
		||||
    )
 | 
			
		||||
    to_wheel_parser.add_argument(
 | 
			
		||||
        "--manylinux-version",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="2.28",
 | 
			
		||||
        help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
 | 
			
		||||
    )
 | 
			
		||||
    to_wheel_parser.set_defaults(func=kernels_to_wheel)
 | 
			
		||||
 | 
			
		||||
    # Add generate-readme subcommand parser
 | 
			
		||||
    generate_readme_parser = subparsers.add_parser(
 | 
			
		||||
        "generate-readme",
 | 
			
		||||
        help="Generate README snippets for a kernel's public functions",
 | 
			
		||||
    )
 | 
			
		||||
    generate_readme_parser.add_argument(
 | 
			
		||||
        "repo_id",
 | 
			
		||||
        type=str,
 | 
			
		||||
        help="The kernel repo ID (e.g., kernels-community/activation)",
 | 
			
		||||
    )
 | 
			
		||||
    generate_readme_parser.add_argument(
 | 
			
		||||
        "--revision",
 | 
			
		||||
        type=str,
 | 
			
		||||
        default="main",
 | 
			
		||||
        help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
 | 
			
		||||
    )
 | 
			
		||||
    generate_readme_parser.set_defaults(
 | 
			
		||||
        func=lambda args: generate_readme_for_kernel(
 | 
			
		||||
            repo_id=args.repo_id, revision=args.revision
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    args.func(args)
 | 
			
		||||
 | 
			
		||||
@ -77,6 +141,24 @@ def download_kernels(args):
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def kernels_to_wheel(args):
 | 
			
		||||
    variants_path = install_kernel_all_variants(
 | 
			
		||||
        repo_id=args.repo_id, revision=f"v{args.version}"
 | 
			
		||||
    )
 | 
			
		||||
    for variant_path in variants_path.iterdir():
 | 
			
		||||
        if not variant_path.is_dir():
 | 
			
		||||
            continue
 | 
			
		||||
        wheel_path = build_variant_to_wheel(
 | 
			
		||||
            manylinux_version=args.manylinux_version,
 | 
			
		||||
            python_version=args.python_version,
 | 
			
		||||
            repo_id=args.repo_id,
 | 
			
		||||
            version=args.version,
 | 
			
		||||
            variant_path=variant_path,
 | 
			
		||||
            wheel_dir=Path("."),
 | 
			
		||||
        )
 | 
			
		||||
        print(f"☸️ {wheel_path.name}", file=sys.stderr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def lock_kernels(args):
 | 
			
		||||
    with open(args.project_dir / "pyproject.toml", "rb") as f:
 | 
			
		||||
        data = tomllib.load(f)
 | 
			
		||||
@ -91,6 +173,33 @@ def lock_kernels(args):
 | 
			
		||||
        json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upload_kernels(args):
 | 
			
		||||
    kernel_dir = Path(args.kernel_dir).resolve()
 | 
			
		||||
    build_dir = kernel_dir / "build"
 | 
			
		||||
    if not kernel_dir.is_dir():
 | 
			
		||||
        raise ValueError(f"{kernel_dir} is not a directory")
 | 
			
		||||
    if not build_dir.is_dir():
 | 
			
		||||
        raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
 | 
			
		||||
 | 
			
		||||
    repo_id = create_repo(
 | 
			
		||||
        repo_id=args.repo_id, private=args.private, exist_ok=True
 | 
			
		||||
    ).repo_id
 | 
			
		||||
 | 
			
		||||
    delete_patterns: set[str] = set()
 | 
			
		||||
    for build_variant in build_dir.iterdir():
 | 
			
		||||
        if build_variant.is_dir():
 | 
			
		||||
            delete_patterns.add(f"{build_variant.name}/**")
 | 
			
		||||
 | 
			
		||||
    upload_folder(
 | 
			
		||||
        repo_id=repo_id,
 | 
			
		||||
        folder_path=build_dir,
 | 
			
		||||
        path_in_repo="build",
 | 
			
		||||
        delete_patterns=list(delete_patterns),
 | 
			
		||||
        commit_message="Build uploaded using `kernels`.",
 | 
			
		||||
    )
 | 
			
		||||
    print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _JSONEncoder(json.JSONEncoder):
 | 
			
		||||
    def default(self, o):
 | 
			
		||||
        if dataclasses.is_dataclass(o):
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										242
									
								
								src/kernels/doc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										242
									
								
								src/kernels/doc.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,242 @@
 | 
			
		||||
import inspect
 | 
			
		||||
import re
 | 
			
		||||
import sys
 | 
			
		||||
from types import ModuleType
 | 
			
		||||
 | 
			
		||||
import yaml
 | 
			
		||||
 | 
			
		||||
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
 | 
			
		||||
from .utils import get_kernel
 | 
			
		||||
 | 
			
		||||
_RE_PARAMETERS = re.compile(
 | 
			
		||||
    r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
 | 
			
		||||
)
 | 
			
		||||
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
 | 
			
		||||
_RE_RETURNTYPE = re.compile(
 | 
			
		||||
    r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _extract_description_before_tags(docstring_mdx: str) -> str:
 | 
			
		||||
    """Extract the description part of a docstring before any tags."""
 | 
			
		||||
    params_pos = docstring_mdx.find("<parameters>")
 | 
			
		||||
    returns_pos = docstring_mdx.find("<returns>")
 | 
			
		||||
    returntype_pos = docstring_mdx.find("<returntype>")
 | 
			
		||||
    positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
 | 
			
		||||
 | 
			
		||||
    if positions:
 | 
			
		||||
        first_tag_pos = min(positions)
 | 
			
		||||
        return docstring_mdx[:first_tag_pos].strip()
 | 
			
		||||
    else:
 | 
			
		||||
        return docstring_mdx.strip()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
 | 
			
		||||
    """Print the parameters section from a docstring."""
 | 
			
		||||
    matches = _RE_PARAMETERS.findall(docstring_mdx)
 | 
			
		||||
    if matches:
 | 
			
		||||
        header = "#" * header_level
 | 
			
		||||
        print(f"\n{header} Parameters")
 | 
			
		||||
        for match in matches:
 | 
			
		||||
            print(f"\n{match[0].strip()}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _print_returns_section(
 | 
			
		||||
    docstring_mdx: str, *, context_name: str, header_level: int
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Print the returns section from a docstring."""
 | 
			
		||||
    return_matches = _RE_RETURNS.findall(docstring_mdx)
 | 
			
		||||
    returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
 | 
			
		||||
 | 
			
		||||
    if return_matches or returntype_matches:
 | 
			
		||||
        header = "#" * header_level
 | 
			
		||||
        print(f"\n{header} Returns")
 | 
			
		||||
 | 
			
		||||
        if returntype_matches:
 | 
			
		||||
            if len(returntype_matches) > 1:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    f"More than one <returntype> tag found in docstring for {context_name}"
 | 
			
		||||
                )
 | 
			
		||||
            print(f"\n**Type**: {returntype_matches[0][0].strip()}")
 | 
			
		||||
 | 
			
		||||
        if return_matches:
 | 
			
		||||
            for match in return_matches:
 | 
			
		||||
                print(f"\n{match[0].strip()}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_docstring(obj, use_dict_check: bool = False) -> str:
 | 
			
		||||
    """Get docstring from an object, with fallback to default message."""
 | 
			
		||||
    # Check whether the class/method itself has docs and not just
 | 
			
		||||
    # the superclass.
 | 
			
		||||
    if use_dict_check:
 | 
			
		||||
        has_doc = obj.__dict__.get("__doc__", None) is not None
 | 
			
		||||
    else:
 | 
			
		||||
        has_doc = getattr(obj, "__doc__", None) is not None
 | 
			
		||||
 | 
			
		||||
    # We use inspect.getdoc because it does normalization.
 | 
			
		||||
    doc = inspect.getdoc(obj)
 | 
			
		||||
 | 
			
		||||
    return doc if has_doc and doc is not None else "No documentation available."
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _process_and_print_docstring(
 | 
			
		||||
    docstring: str, *, kernel_name: str, context_name: str, header_level: int
 | 
			
		||||
) -> None:
 | 
			
		||||
    """Convert docstring to MDX and print description, parameters, and returns sections."""
 | 
			
		||||
    docstring_mdx = convert_rst_docstring_to_mdx(
 | 
			
		||||
        docstring, page_info={"package_name": kernel_name}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Print the description
 | 
			
		||||
    description = _extract_description_before_tags(docstring_mdx)
 | 
			
		||||
    print(f"\n{description}")
 | 
			
		||||
 | 
			
		||||
    # Print parameters and returns sections
 | 
			
		||||
    _print_parameters_section(docstring_mdx, header_level=header_level)
 | 
			
		||||
    _print_returns_section(
 | 
			
		||||
        docstring_mdx, context_name=context_name, header_level=header_level
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
 | 
			
		||||
    kernel_module = get_kernel(repo_id=repo_id, revision=revision)
 | 
			
		||||
    kernel_name = repo_id.split("/")[-1].replace("-", "_")
 | 
			
		||||
 | 
			
		||||
    generate_metadata(kernel_module)
 | 
			
		||||
    generate_kernel_doc(kernel_module, kernel_name)
 | 
			
		||||
    generate_function_doc(kernel_module, kernel_name)
 | 
			
		||||
    generate_layers_doc(kernel_module, kernel_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_metadata(module: ModuleType) -> None:
 | 
			
		||||
    metadata = getattr(module, "__kernel_metadata__", {})
 | 
			
		||||
    if "tags" not in metadata:
 | 
			
		||||
        metadata["tags"] = ["kernel"]
 | 
			
		||||
    else:
 | 
			
		||||
        if "kernel" not in metadata["tags"]:
 | 
			
		||||
            metadata["tags"].append("kernel")
 | 
			
		||||
 | 
			
		||||
    print("---")
 | 
			
		||||
    print(yaml.dump(metadata), end="")
 | 
			
		||||
    print("---")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
 | 
			
		||||
    docstring = module.__doc__.strip() if module.__doc__ is not None else None
 | 
			
		||||
    if docstring:
 | 
			
		||||
        title, rest = docstring.split("\n", 1)
 | 
			
		||||
        print(f"# {title.strip()}")
 | 
			
		||||
        print(
 | 
			
		||||
            f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
 | 
			
		||||
    print("\n## Functions")
 | 
			
		||||
 | 
			
		||||
    # Track if we found any functions
 | 
			
		||||
    found_functions = False
 | 
			
		||||
 | 
			
		||||
    for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
 | 
			
		||||
        # Do not include imported functions.
 | 
			
		||||
        if func.__module__ != kernel_module.__name__:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Exclude private functions.
 | 
			
		||||
        if name.startswith("_"):
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        found_functions = True
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            sig = inspect.signature(func)
 | 
			
		||||
            docstring = _get_docstring(func)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            print(
 | 
			
		||||
                f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
 | 
			
		||||
                file=sys.stderr,
 | 
			
		||||
            )
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        print(f"\n### Function `{name}`")
 | 
			
		||||
        print(f"\n`{sig}`")
 | 
			
		||||
 | 
			
		||||
        _process_and_print_docstring(
 | 
			
		||||
            docstring, kernel_name=kernel_name, context_name=name, header_level=3
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if not found_functions:
 | 
			
		||||
        print("\nNo public top-level functions.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
 | 
			
		||||
    # Check if layers module is available
 | 
			
		||||
    layers_module = getattr(kernel_module, "layers", None)
 | 
			
		||||
    if layers_module is None:
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    print("\n## Layers")
 | 
			
		||||
 | 
			
		||||
    # Track if we found any classes
 | 
			
		||||
    found_classes = False
 | 
			
		||||
 | 
			
		||||
    for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
 | 
			
		||||
        # Exclude classes that were imported.
 | 
			
		||||
        if cls.__module__ != layers_module.__name__:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        found_classes = True
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            # Get docstring, but not from superclasses.
 | 
			
		||||
            class_docstring = _get_docstring(cls, use_dict_check=True)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            print(
 | 
			
		||||
                f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
 | 
			
		||||
                file=sys.stderr,
 | 
			
		||||
            )
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        print(f"\n### Class `{class_name}`")
 | 
			
		||||
 | 
			
		||||
        # Always print class description (helper handles conversion and formatting)
 | 
			
		||||
        class_docstring_mdx = convert_rst_docstring_to_mdx(
 | 
			
		||||
            class_docstring, page_info={"package_name": kernel_name}
 | 
			
		||||
        )
 | 
			
		||||
        description = _extract_description_before_tags(class_docstring_mdx)
 | 
			
		||||
        print(f"\n{description}")
 | 
			
		||||
 | 
			
		||||
        # Document methods
 | 
			
		||||
        print("\n#### Methods")
 | 
			
		||||
 | 
			
		||||
        for method_name, method in inspect.getmembers(cls, inspect.isfunction):
 | 
			
		||||
            # Note: also skip __init__, since extension layers cannot have a constructor.
 | 
			
		||||
            if method_name.startswith("_"):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            # Skip methods from superclasses.
 | 
			
		||||
            if method_name not in cls.__dict__:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                sig = inspect.signature(method)
 | 
			
		||||
                method_docstring = _get_docstring(method)
 | 
			
		||||
            except ValueError:
 | 
			
		||||
                print(
 | 
			
		||||
                    f"Warning: Could not retrieve signature for {method_name} in {class_name}",
 | 
			
		||||
                    file=sys.stderr,
 | 
			
		||||
                )
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            print(f"\n##### Method `{method_name}`")
 | 
			
		||||
            print(f"\n`{sig}`")
 | 
			
		||||
 | 
			
		||||
            _process_and_print_docstring(
 | 
			
		||||
                method_docstring,
 | 
			
		||||
                kernel_name=kernel_name,
 | 
			
		||||
                context_name=method_name,
 | 
			
		||||
                header_level=6,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    if not found_classes:
 | 
			
		||||
        print("\nNo layers defined.")
 | 
			
		||||
							
								
								
									
										1180
									
								
								src/kernels/layer.py
									
									
									
									
									
								
							
							
						
						
									
										1180
									
								
								src/kernels/layer.py
									
									
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@ -4,10 +4,8 @@ from pathlib import Path
 | 
			
		||||
from typing import Dict, List, Tuple
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import HfApi
 | 
			
		||||
from huggingface_hub.hf_api import GitRefInfo
 | 
			
		||||
from packaging.specifiers import SpecifierSet
 | 
			
		||||
from packaging.version import InvalidVersion, Version
 | 
			
		||||
 | 
			
		||||
from kernels._versions import resolve_version_spec_as_ref
 | 
			
		||||
from kernels.compat import tomllib
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -31,20 +29,6 @@ class KernelLock:
 | 
			
		||||
        return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
 | 
			
		||||
    """Get kernel versions that are available in the repository."""
 | 
			
		||||
    versions = {}
 | 
			
		||||
    for tag in HfApi().list_repo_refs(repo_id).tags:
 | 
			
		||||
        if not tag.name.startswith("v"):
 | 
			
		||||
            continue
 | 
			
		||||
        try:
 | 
			
		||||
            versions[Version(tag.name[1:])] = tag
 | 
			
		||||
        except InvalidVersion:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
    return versions
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
 | 
			
		||||
    """
 | 
			
		||||
    Get the locks for a kernel with the given version spec.
 | 
			
		||||
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
 | 
			
		||||
    The version specifier can be any valid Python version specifier:
 | 
			
		||||
    https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
 | 
			
		||||
    """
 | 
			
		||||
    versions = _get_available_versions(repo_id)
 | 
			
		||||
    requirement = SpecifierSet(version_spec)
 | 
			
		||||
    accepted_versions = sorted(requirement.filter(versions.keys()))
 | 
			
		||||
 | 
			
		||||
    if len(accepted_versions) == 0:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"No version of `{repo_id}` satisfies requirement: {version_spec}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    tag_for_newest = versions[accepted_versions[-1]]
 | 
			
		||||
    tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
 | 
			
		||||
 | 
			
		||||
    r = HfApi().repo_info(
 | 
			
		||||
        repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,7 @@ import importlib
 | 
			
		||||
import importlib.metadata
 | 
			
		||||
import inspect
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
import platform
 | 
			
		||||
import sys
 | 
			
		||||
@ -12,29 +13,58 @@ from pathlib import Path
 | 
			
		||||
from types import ModuleType
 | 
			
		||||
from typing import Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
from huggingface_hub import file_exists, snapshot_download
 | 
			
		||||
from packaging.version import parse
 | 
			
		||||
 | 
			
		||||
from kernels._versions import select_revision_or_version
 | 
			
		||||
from kernels.lockfile import KernelLock, VariantLock
 | 
			
		||||
 | 
			
		||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
 | 
			
		||||
 | 
			
		||||
def _get_cache_dir() -> Optional[str]:
 | 
			
		||||
    """Returns the kernels cache directory."""
 | 
			
		||||
    cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
 | 
			
		||||
    if cache_dir is not None:
 | 
			
		||||
        logging.warning(
 | 
			
		||||
            "HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
 | 
			
		||||
        )
 | 
			
		||||
        return cache_dir
 | 
			
		||||
 | 
			
		||||
    return os.environ.get("KERNELS_CACHE", None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
CACHE_DIR: Optional[str] = _get_cache_dir()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_variant() -> str:
 | 
			
		||||
    import torch
 | 
			
		||||
 | 
			
		||||
    if torch.version.cuda is None:
 | 
			
		||||
    if torch.version.cuda is not None:
 | 
			
		||||
        cuda_version = parse(torch.version.cuda)
 | 
			
		||||
        compute_framework = f"cu{cuda_version.major}{cuda_version.minor}"
 | 
			
		||||
    elif torch.version.hip is not None:
 | 
			
		||||
        rocm_version = parse(torch.version.hip.split("-")[0])
 | 
			
		||||
        compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
 | 
			
		||||
    elif torch.backends.mps.is_available():
 | 
			
		||||
        compute_framework = "metal"
 | 
			
		||||
    elif torch.version.xpu is not None:
 | 
			
		||||
        version = torch.version.xpu
 | 
			
		||||
        compute_framework = f"xpu{version[0:4]}{version[5:6]}"
 | 
			
		||||
    else:
 | 
			
		||||
        raise AssertionError(
 | 
			
		||||
            "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
 | 
			
		||||
            "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    torch_version = parse(torch.__version__)
 | 
			
		||||
    cuda_version = parse(torch.version.cuda)
 | 
			
		||||
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
 | 
			
		||||
    cpu = platform.machine()
 | 
			
		||||
    os = platform.system().lower()
 | 
			
		||||
 | 
			
		||||
    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
 | 
			
		||||
    if os == "darwin":
 | 
			
		||||
        cpu = "aarch64" if cpu == "arm64" else cpu
 | 
			
		||||
        return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
 | 
			
		||||
 | 
			
		||||
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
 | 
			
		||||
 | 
			
		||||
    return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def universal_build_variant() -> str:
 | 
			
		||||
@ -69,7 +99,20 @@ def install_kernel(
 | 
			
		||||
    """
 | 
			
		||||
    Download a kernel for the current environment to the cache.
 | 
			
		||||
 | 
			
		||||
    The output path is validated againt `hash` when set.
 | 
			
		||||
    The output path is validated against the hashes in `variant_locks` when provided.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_id (`str`):
 | 
			
		||||
            The Hub repository containing the kernel.
 | 
			
		||||
        revision (`str`):
 | 
			
		||||
            The specific revision (branch, tag, or commit) to download.
 | 
			
		||||
        local_files_only (`bool`, *optional*, defaults to `False`):
 | 
			
		||||
            Whether to only use local files and not download from the Hub.
 | 
			
		||||
        variant_locks (`Dict[str, VariantLock]`, *optional*):
 | 
			
		||||
            Optional dictionary of variant locks for validation.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
 | 
			
		||||
    """
 | 
			
		||||
    package_name = package_name_from_repo_id(repo_id)
 | 
			
		||||
    variant = build_variant()
 | 
			
		||||
@ -84,6 +127,23 @@ def install_kernel(
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        return _load_kernel_from_path(repo_path, package_name, variant_locks)
 | 
			
		||||
    except FileNotFoundError:
 | 
			
		||||
        # Redo with more specific error message.
 | 
			
		||||
        raise FileNotFoundError(
 | 
			
		||||
            f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _load_kernel_from_path(
 | 
			
		||||
    repo_path: Path,
 | 
			
		||||
    package_name: str,
 | 
			
		||||
    variant_locks: Optional[Dict[str, VariantLock]] = None,
 | 
			
		||||
) -> Tuple[str, Path]:
 | 
			
		||||
    variant = build_variant()
 | 
			
		||||
    universal_variant = universal_build_variant()
 | 
			
		||||
 | 
			
		||||
    variant_path = repo_path / "build" / variant
 | 
			
		||||
    universal_variant_path = repo_path / "build" / universal_variant
 | 
			
		||||
 | 
			
		||||
@ -102,7 +162,7 @@ def install_kernel(
 | 
			
		||||
 | 
			
		||||
    if not os.path.exists(module_init_path):
 | 
			
		||||
        raise FileNotFoundError(
 | 
			
		||||
            f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
 | 
			
		||||
            f"Kernel at path `{repo_path}` does not have build: {variant}"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return package_name, variant_path
 | 
			
		||||
@ -139,17 +199,128 @@ def install_kernel_all_variants(
 | 
			
		||||
    return repo_path / "build"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
 | 
			
		||||
def get_kernel(
 | 
			
		||||
    repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
 | 
			
		||||
) -> ModuleType:
 | 
			
		||||
    """
 | 
			
		||||
    Load a kernel from the kernel hub.
 | 
			
		||||
 | 
			
		||||
    This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
 | 
			
		||||
    and then loads the kernel.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_id (`str`):
 | 
			
		||||
            The Hub repository containing the kernel.
 | 
			
		||||
        revision (`str`, *optional*, defaults to `"main"`):
 | 
			
		||||
            The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
 | 
			
		||||
        version (`str`, *optional*):
 | 
			
		||||
            The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
 | 
			
		||||
            Cannot be used together with `revision`.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `ModuleType`: The imported kernel module.
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
        ```python
 | 
			
		||||
        import torch
 | 
			
		||||
        from kernels import get_kernel
 | 
			
		||||
 | 
			
		||||
        activation = get_kernel("kernels-community/activation")
 | 
			
		||||
        x = torch.randn(10, 20, device="cuda")
 | 
			
		||||
        out = torch.empty_like(x)
 | 
			
		||||
        result = activation.silu_and_mul(out, x)
 | 
			
		||||
        ```
 | 
			
		||||
    """
 | 
			
		||||
    revision = select_revision_or_version(repo_id, revision, version)
 | 
			
		||||
    package_name, package_path = install_kernel(repo_id, revision=revision)
 | 
			
		||||
    return import_from_path(package_name, package_path / package_name / "__init__.py")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
 | 
			
		||||
    """
 | 
			
		||||
    Import a kernel from a local kernel repository path.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_path (`Path`):
 | 
			
		||||
            The local path to the kernel repository.
 | 
			
		||||
        package_name (`str`):
 | 
			
		||||
            The name of the package to import from the repository.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `ModuleType`: The imported kernel module.
 | 
			
		||||
    """
 | 
			
		||||
    variant = build_variant()
 | 
			
		||||
    universal_variant = universal_build_variant()
 | 
			
		||||
 | 
			
		||||
    # Presume we were given the top level path of the kernel repository.
 | 
			
		||||
    for base_path in [repo_path, repo_path / "build"]:
 | 
			
		||||
        # Prefer the universal variant if it exists.
 | 
			
		||||
        for v in [universal_variant, variant]:
 | 
			
		||||
            package_path = base_path / v / package_name / "__init__.py"
 | 
			
		||||
            if package_path.exists():
 | 
			
		||||
                return import_from_path(package_name, package_path)
 | 
			
		||||
 | 
			
		||||
    # If we didn't find the package in the repo we may have a explicit
 | 
			
		||||
    # package path.
 | 
			
		||||
    package_path = repo_path / package_name / "__init__.py"
 | 
			
		||||
    if package_path.exists():
 | 
			
		||||
        return import_from_path(package_name, package_path)
 | 
			
		||||
 | 
			
		||||
    raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def has_kernel(
 | 
			
		||||
    repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
 | 
			
		||||
) -> bool:
 | 
			
		||||
    """
 | 
			
		||||
    Check whether a kernel build exists for the current environment (Torch version and compute framework).
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_id (`str`):
 | 
			
		||||
            The Hub repository containing the kernel.
 | 
			
		||||
        revision (`str`, *optional*, defaults to `"main"`):
 | 
			
		||||
            The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
 | 
			
		||||
        version (`str`, *optional*):
 | 
			
		||||
            The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
 | 
			
		||||
            Cannot be used together with `revision`.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `bool`: `True` if a kernel is available for the current environment.
 | 
			
		||||
    """
 | 
			
		||||
    revision = select_revision_or_version(repo_id, revision, version)
 | 
			
		||||
 | 
			
		||||
    package_name = package_name_from_repo_id(repo_id)
 | 
			
		||||
    variant = build_variant()
 | 
			
		||||
    universal_variant = universal_build_variant()
 | 
			
		||||
 | 
			
		||||
    if file_exists(
 | 
			
		||||
        repo_id,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
        filename=f"build/{universal_variant}/{package_name}/__init__.py",
 | 
			
		||||
    ):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    return file_exists(
 | 
			
		||||
        repo_id,
 | 
			
		||||
        revision=revision,
 | 
			
		||||
        filename=f"build/{variant}/{package_name}/__init__.py",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
 | 
			
		||||
    """
 | 
			
		||||
    Get a pre-downloaded, locked kernel.
 | 
			
		||||
 | 
			
		||||
    If `lockfile` is not specified, the lockfile will be loaded from the
 | 
			
		||||
    caller's package metadata.
 | 
			
		||||
    If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_id (`str`):
 | 
			
		||||
            The Hub repository containing the kernel.
 | 
			
		||||
        lockfile (`Path`, *optional*):
 | 
			
		||||
            Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `ModuleType`: The imported kernel module.
 | 
			
		||||
    """
 | 
			
		||||
    if lockfile is None:
 | 
			
		||||
        locked_sha = _get_caller_locked_kernel(repo_id)
 | 
			
		||||
@ -194,7 +365,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
 | 
			
		||||
    """Get a kernel using a lock file."""
 | 
			
		||||
    """
 | 
			
		||||
    Get a kernel using a lock file.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        repo_id (`str`):
 | 
			
		||||
            The Hub repository containing the kernel.
 | 
			
		||||
        local_files_only (`bool`, *optional*, defaults to `False`):
 | 
			
		||||
            Whether to only use local files and not download from the Hub.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        `ModuleType`: The imported kernel module.
 | 
			
		||||
    """
 | 
			
		||||
    locked_sha = _get_caller_locked_kernel(repo_id)
 | 
			
		||||
 | 
			
		||||
    if locked_sha is None:
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										186
									
								
								src/kernels/wheel.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								src/kernels/wheel.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,186 @@
 | 
			
		||||
import email.policy
 | 
			
		||||
import os
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from email.message import Message
 | 
			
		||||
from importlib.metadata import PackageNotFoundError, version
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    KERNELS_VERSION = version("kernels")
 | 
			
		||||
except PackageNotFoundError:
 | 
			
		||||
    KERNELS_VERSION = "unknown"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Metadata:
 | 
			
		||||
    name: str
 | 
			
		||||
    version: str
 | 
			
		||||
    cuda_version: Optional[str]
 | 
			
		||||
    cxx_abi_version: Optional[str]
 | 
			
		||||
    torch_version: Optional[str]
 | 
			
		||||
    os: Optional[str]
 | 
			
		||||
    platform: Optional[str]
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_universal(self) -> bool:
 | 
			
		||||
        return self.platform is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_variant_to_wheel(
 | 
			
		||||
    repo_id: str,
 | 
			
		||||
    *,
 | 
			
		||||
    version: str,
 | 
			
		||||
    variant_path: Path,
 | 
			
		||||
    wheel_dir: Path,
 | 
			
		||||
    manylinux_version: str = "2.28",
 | 
			
		||||
    python_version: str = "3.9",
 | 
			
		||||
) -> Path:
 | 
			
		||||
    """
 | 
			
		||||
    Create a wheel file from the variant path.
 | 
			
		||||
    """
 | 
			
		||||
    name = repo_id.split("/")[-1].replace("_", "-")
 | 
			
		||||
    metadata = extract_metadata(name, version, variant_path)
 | 
			
		||||
    return build_wheel(
 | 
			
		||||
        metadata,
 | 
			
		||||
        variant_path=variant_path,
 | 
			
		||||
        wheel_dir=wheel_dir,
 | 
			
		||||
        manylinux_version=manylinux_version,
 | 
			
		||||
        python_version=python_version,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
 | 
			
		||||
    """
 | 
			
		||||
    Extract metadata from the variant path.
 | 
			
		||||
    """
 | 
			
		||||
    if variant_path.name == "torch-universal":
 | 
			
		||||
        return Metadata(
 | 
			
		||||
            name=name,
 | 
			
		||||
            version=version,
 | 
			
		||||
            cuda_version=None,
 | 
			
		||||
            cxx_abi_version=None,
 | 
			
		||||
            torch_version=None,
 | 
			
		||||
            os=None,
 | 
			
		||||
            platform=None,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if not variant_path.name.startswith("torch"):
 | 
			
		||||
        raise ValueError("Currently only conversion of Torch kernels is supported.")
 | 
			
		||||
 | 
			
		||||
    variant_parts = variant_path.name.removeprefix("torch").split("-")
 | 
			
		||||
    if len(variant_parts) != 5:
 | 
			
		||||
        raise ValueError(f"Invalid variant name: {variant_path.name}")
 | 
			
		||||
 | 
			
		||||
    torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}"
 | 
			
		||||
    cpp_abi_version = variant_parts[1].removeprefix("cxx")
 | 
			
		||||
    cuda_version = variant_parts[2].removeprefix("cu")
 | 
			
		||||
    platform = variant_parts[3].replace("-", "_")
 | 
			
		||||
    os = variant_parts[4]
 | 
			
		||||
 | 
			
		||||
    return Metadata(
 | 
			
		||||
        name=name,
 | 
			
		||||
        version=version,
 | 
			
		||||
        cuda_version=cuda_version,
 | 
			
		||||
        cxx_abi_version=cpp_abi_version,
 | 
			
		||||
        torch_version=torch_version,
 | 
			
		||||
        os=os,
 | 
			
		||||
        platform=platform,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def build_wheel(
 | 
			
		||||
    metadata: Metadata,
 | 
			
		||||
    *,
 | 
			
		||||
    variant_path: Path,
 | 
			
		||||
    wheel_dir: Path,
 | 
			
		||||
    manylinux_version: str = "2.28",
 | 
			
		||||
    python_version: str = "3.9",
 | 
			
		||||
) -> Path:
 | 
			
		||||
    """
 | 
			
		||||
    Build the wheel file.
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        from wheel.wheelfile import WheelFile  # type: ignore
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        raise ImportError(
 | 
			
		||||
            "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    name = metadata.name.replace("-", "_")
 | 
			
		||||
    python_version_flat = python_version.replace(".", "")
 | 
			
		||||
 | 
			
		||||
    if metadata.is_universal:
 | 
			
		||||
        python_tag = f"py{python_version_flat}"
 | 
			
		||||
        abi_tag = "none"
 | 
			
		||||
        platform_tag = "any"
 | 
			
		||||
        wheel_filename = (
 | 
			
		||||
            f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
 | 
			
		||||
        )
 | 
			
		||||
        dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
 | 
			
		||||
        root_is_purelib = "true"
 | 
			
		||||
        requires_dist_torch = "torch"
 | 
			
		||||
    else:
 | 
			
		||||
        python_tag = f"cp{python_version_flat}"
 | 
			
		||||
        abi_tag = "abi3"
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            metadata.torch_version is None
 | 
			
		||||
            or metadata.cuda_version is None
 | 
			
		||||
            or metadata.cxx_abi_version is None
 | 
			
		||||
            or metadata.os is None
 | 
			
		||||
            or metadata.platform is None
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
 | 
			
		||||
 | 
			
		||||
        if metadata.os == "linux":
 | 
			
		||||
            platform_tag = (
 | 
			
		||||
                f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
 | 
			
		||||
 | 
			
		||||
        wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
 | 
			
		||||
        dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
 | 
			
		||||
        root_is_purelib = "false"
 | 
			
		||||
        requires_dist_torch = f"torch=={metadata.torch_version}.*"
 | 
			
		||||
 | 
			
		||||
    wheel_path = wheel_dir / wheel_filename
 | 
			
		||||
 | 
			
		||||
    wheel_msg = Message(email.policy.compat32)
 | 
			
		||||
    wheel_msg.add_header("Wheel-Version", "1.0")
 | 
			
		||||
    wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
 | 
			
		||||
    wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
 | 
			
		||||
    wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")
 | 
			
		||||
 | 
			
		||||
    metadata_msg = Message(email.policy.compat32)
 | 
			
		||||
    metadata_msg.add_header("Metadata-Version", "2.1")
 | 
			
		||||
    metadata_msg.add_header("Name", name)
 | 
			
		||||
    metadata_msg.add_header("Version", metadata.version)
 | 
			
		||||
    metadata_msg.add_header("Summary", f"{name} kernel")
 | 
			
		||||
    metadata_msg.add_header("Requires-Python", ">=3.9")
 | 
			
		||||
    metadata_msg.add_header("Requires-Dist", requires_dist_torch)
 | 
			
		||||
 | 
			
		||||
    source_pkg_dir = variant_path / name
 | 
			
		||||
 | 
			
		||||
    with WheelFile(wheel_path, "w") as wheel_file:
 | 
			
		||||
        for root, dirnames, filenames in os.walk(source_pkg_dir):
 | 
			
		||||
            for filename in filenames:
 | 
			
		||||
                if filename.endswith(".pyc"):
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                abs_filepath = os.path.join(root, filename)
 | 
			
		||||
                entry_name = os.path.relpath(abs_filepath, variant_path)
 | 
			
		||||
                wheel_file.write(abs_filepath, entry_name)
 | 
			
		||||
 | 
			
		||||
        wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
 | 
			
		||||
        wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
 | 
			
		||||
 | 
			
		||||
        metadata_path = os.path.join(dist_info_dir_name, "METADATA")
 | 
			
		||||
        wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
 | 
			
		||||
 | 
			
		||||
    return wheel_path
 | 
			
		||||
							
								
								
									
										41
									
								
								tests/conftest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								tests/conftest.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,41 @@
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
has_cuda = (
 | 
			
		||||
    hasattr(torch.version, "cuda")
 | 
			
		||||
    and torch.version.cuda is not None
 | 
			
		||||
    and torch.cuda.device_count() > 0
 | 
			
		||||
)
 | 
			
		||||
has_rocm = (
 | 
			
		||||
    hasattr(torch.version, "hip")
 | 
			
		||||
    and torch.version.hip is not None
 | 
			
		||||
    and torch.cuda.device_count() > 0
 | 
			
		||||
)
 | 
			
		||||
has_xpu = (
 | 
			
		||||
    hasattr(torch.version, "xpu")
 | 
			
		||||
    and torch.version.xpu is not None
 | 
			
		||||
    and torch.xpu.device_count() > 0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_addoption(parser):
 | 
			
		||||
    parser.addoption(
 | 
			
		||||
        "--token",
 | 
			
		||||
        action="store_true",
 | 
			
		||||
        help="run tests that require a token with write permissions",
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_runtest_setup(item):
 | 
			
		||||
    if "cuda_only" in item.keywords and not has_cuda:
 | 
			
		||||
        pytest.skip("skipping CUDA-only test on host without CUDA")
 | 
			
		||||
    if "rocm_only" in item.keywords and not has_rocm:
 | 
			
		||||
        pytest.skip("skipping ROCm-only test on host without ROCm")
 | 
			
		||||
    if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
 | 
			
		||||
        pytest.skip("skipping macOS-only test on non-macOS platform")
 | 
			
		||||
    if "xpu_only" in item.keywords and not has_xpu:
 | 
			
		||||
        pytest.skip("skipping XPU-only test on host without XPU")
 | 
			
		||||
    if "token" in item.keywords and not item.config.getoption("--token"):
 | 
			
		||||
        pytest.skip("need --token option to run this test")
 | 
			
		||||
@ -1,54 +1,82 @@
 | 
			
		||||
[
 | 
			
		||||
  {
 | 
			
		||||
    "repo_id": "kernels-community/activation",
 | 
			
		||||
    "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
 | 
			
		||||
    "sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
 | 
			
		||||
    "variants": {
 | 
			
		||||
      "torch25-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
 | 
			
		||||
        "hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx11-cu121-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
 | 
			
		||||
        "hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx11-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
 | 
			
		||||
        "hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
 | 
			
		||||
        "hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu121-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
 | 
			
		||||
        "hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch25-cxx98-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
 | 
			
		||||
        "hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
 | 
			
		||||
        "hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
 | 
			
		||||
        "hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
 | 
			
		||||
        "hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
 | 
			
		||||
        "hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu124-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
 | 
			
		||||
        "hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch26-cxx98-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
 | 
			
		||||
        "hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu118-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu126-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu126-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu128-aarch64-linux": {
 | 
			
		||||
        "hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      },
 | 
			
		||||
      "torch27-cxx11-cu128-x86_64-linux": {
 | 
			
		||||
        "hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										12
									
								
								tests/layer_locking/kernels.lock
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								tests/layer_locking/kernels.lock
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,12 @@
 | 
			
		||||
[
 | 
			
		||||
  {
 | 
			
		||||
    "repo_id": "kernels-test/versions",
 | 
			
		||||
    "sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
 | 
			
		||||
    "variants": {
 | 
			
		||||
      "torch-universal": {
 | 
			
		||||
        "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
 | 
			
		||||
        "hash_type": "git_lfs_concat"
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
]
 | 
			
		||||
							
								
								
									
										2
									
								
								tests/layer_locking/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								tests/layer_locking/pyproject.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,2 @@
 | 
			
		||||
[tool.kernels.dependencies]
 | 
			
		||||
"kernels-test/versions" = ">=0.1.0,<0.2.0"
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from kernels import get_kernel
 | 
			
		||||
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
@ -9,6 +9,25 @@ def kernel():
 | 
			
		||||
    return get_kernel("kernels-community/activation")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def local_kernel_path():
 | 
			
		||||
    package_name, path = install_kernel("kernels-community/activation", "main")
 | 
			
		||||
    # Path is the build variant path (build/torch-<...>), so the grandparent
 | 
			
		||||
    # is the kernel repository path.
 | 
			
		||||
    return package_name, path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def local_kernel(local_kernel_path):
 | 
			
		||||
    package_name, path = local_kernel_path
 | 
			
		||||
    return get_local_kernel(path.parent.parent, package_name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def metal_kernel():
 | 
			
		||||
    return get_kernel("kernels-test/relu-metal")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def universal_kernel():
 | 
			
		||||
    return get_kernel("kernels-community/triton-scaled-mm")
 | 
			
		||||
@ -21,6 +40,7 @@ def device():
 | 
			
		||||
    return "cuda"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_gelu_fast(kernel, device):
 | 
			
		||||
    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
@ -36,6 +56,100 @@ def test_gelu_fast(kernel, device):
 | 
			
		||||
    assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_local_kernel(local_kernel, device):
 | 
			
		||||
    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
 | 
			
		||||
    local_kernel.gelu_fast(y, x)
 | 
			
		||||
 | 
			
		||||
    expected = torch.tensor(
 | 
			
		||||
        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
 | 
			
		||||
        device=device,
 | 
			
		||||
        dtype=torch.float16,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_local_kernel_path_types(local_kernel_path, device):
 | 
			
		||||
    package_name, path = local_kernel_path
 | 
			
		||||
 | 
			
		||||
    # Top-level repo path
 | 
			
		||||
    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
 | 
			
		||||
    kernel = get_local_kernel(path.parent.parent, package_name)
 | 
			
		||||
    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
 | 
			
		||||
    kernel.gelu_fast(y, x)
 | 
			
		||||
    expected = torch.tensor(
 | 
			
		||||
        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
 | 
			
		||||
        device=device,
 | 
			
		||||
        dtype=torch.float16,
 | 
			
		||||
    )
 | 
			
		||||
    assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
    # Build directory path
 | 
			
		||||
    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
 | 
			
		||||
    kernel = get_local_kernel(path.parent.parent / "build", package_name)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
    kernel.gelu_fast(y, x)
 | 
			
		||||
    assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
    # Explicit package path
 | 
			
		||||
    # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
 | 
			
		||||
    kernel = get_local_kernel(path, package_name)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
    kernel.gelu_fast(y, x)
 | 
			
		||||
    assert torch.allclose(y, expected)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.darwin_only
 | 
			
		||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 | 
			
		||||
def test_relu_metal(metal_kernel, dtype):
 | 
			
		||||
    x = torch.arange(-10, 10, dtype=dtype, device="mps")
 | 
			
		||||
    y = metal_kernel.relu(x)
 | 
			
		||||
    assert torch.allclose(y, torch.relu(x))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "kernel_exists",
 | 
			
		||||
    [
 | 
			
		||||
        ("kernels-community/activation", "main", True),
 | 
			
		||||
        ("kernels-community/triton-layer-norm", "main", True),
 | 
			
		||||
        # Repo only contains Torch 2.4 kernels (and we don't
 | 
			
		||||
        # support/test against this version).
 | 
			
		||||
        ("kernels-test/only-torch-2.4", "main", False),
 | 
			
		||||
        ("google-bert/bert-base-uncased", "87565a309", False),
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
def test_has_kernel(kernel_exists):
 | 
			
		||||
    repo_id, revision, kernel = kernel_exists
 | 
			
		||||
    assert has_kernel(repo_id, revision=revision) == kernel
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_version():
 | 
			
		||||
    kernel = get_kernel("kernels-test/versions")
 | 
			
		||||
    assert kernel.version() == "0.2.0"
 | 
			
		||||
    kernel = get_kernel("kernels-test/versions", version="<1.0.0")
 | 
			
		||||
    assert kernel.version() == "0.2.0"
 | 
			
		||||
    kernel = get_kernel("kernels-test/versions", version="<0.2.0")
 | 
			
		||||
    assert kernel.version() == "0.1.1"
 | 
			
		||||
    kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
 | 
			
		||||
    assert kernel.version() == "0.1.1"
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
 | 
			
		||||
        get_kernel("kernels-test/versions", version=">0.2.0")
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
 | 
			
		||||
        kernel = get_kernel(
 | 
			
		||||
            "kernels-test/versions", revision="v0.1.0", version="<1.0.0"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_universal_kernel(universal_kernel):
 | 
			
		||||
    torch.manual_seed(0)
 | 
			
		||||
    A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
 | 
			
		||||
 | 
			
		||||
@ -16,18 +16,21 @@ def device():
 | 
			
		||||
    return "cuda"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_gelu_small(kernel, device, benchmark):
 | 
			
		||||
    x = torch.randn(32, 32, dtype=torch.float16, device=device)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
    benchmark(kernel.gelu_fast, y, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_gelu_medium(kernel, device, benchmark):
 | 
			
		||||
    x = torch.randn(128, 128, dtype=torch.float16, device=device)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
    benchmark(kernel.gelu_fast, y, x)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_gelu_large(kernel, device, benchmark):
 | 
			
		||||
    x = torch.randn(512, 512, dtype=torch.float16, device=device)
 | 
			
		||||
    y = torch.empty_like(x)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										49
									
								
								tests/test_doctest.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								tests/test_doctest.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,49 @@
 | 
			
		||||
import inspect
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from mktestdocs import check_docstring, get_codeblock_members
 | 
			
		||||
 | 
			
		||||
import kernels
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def all_public_functions():
 | 
			
		||||
    function_list = inspect.getmembers(kernels, inspect.isfunction)
 | 
			
		||||
    return [func for _, func in function_list]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def all_public_classes():
 | 
			
		||||
    class_list = inspect.getmembers(kernels, inspect.isclass)
 | 
			
		||||
    return [cls for _, cls in class_list]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def all_public_class_members():
 | 
			
		||||
    members = get_codeblock_members(*all_public_classes())
 | 
			
		||||
    return members
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "func",
 | 
			
		||||
    all_public_functions(),
 | 
			
		||||
    ids=lambda d: d.__name__,
 | 
			
		||||
)
 | 
			
		||||
def test_func_docstring(func):
 | 
			
		||||
    check_docstring(obj=func)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "cls",
 | 
			
		||||
    all_public_classes(),
 | 
			
		||||
    ids=lambda d: d.__name__,
 | 
			
		||||
)
 | 
			
		||||
def test_class_docstring(cls):
 | 
			
		||||
    check_docstring(obj=cls)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "member", all_public_class_members(), ids=lambda d: d.__qualname__
 | 
			
		||||
)
 | 
			
		||||
def test_member_docstring(member):
 | 
			
		||||
    check_docstring(member)
 | 
			
		||||
							
								
								
									
										230
									
								
								tests/test_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										230
									
								
								tests/test_interval_tree.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,230 @@
 | 
			
		||||
import random
 | 
			
		||||
from typing import Generic, List, Optional, Tuple, TypeVar
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from kernels._interval_tree import IntervalTree, _Node
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SimpleIntervalStore(Generic[T]):
 | 
			
		||||
    """A simple O(n) implementation that stores intervals in a list."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.intervals: List[Tuple[int, int, T]] = []
 | 
			
		||||
 | 
			
		||||
    def insert(self, start: int, end: int, data: T) -> None:
 | 
			
		||||
        """Insert an interval into the store."""
 | 
			
		||||
        # Replace data if the interval already exists.
 | 
			
		||||
        for i, (existing_start, existing_end, existing_data) in enumerate(
 | 
			
		||||
            self.intervals
 | 
			
		||||
        ):
 | 
			
		||||
            if existing_start == start and existing_end == end:
 | 
			
		||||
                self.intervals[i] = (start, end, data)
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
        self.intervals.append((start, end, data))
 | 
			
		||||
 | 
			
		||||
    def find_smallest_interval(self, point: int) -> Optional[T]:
 | 
			
		||||
        """Find the best match using linear search."""
 | 
			
		||||
        matches = []
 | 
			
		||||
        for start, end, data in self.intervals:
 | 
			
		||||
            if start <= point <= end:
 | 
			
		||||
                matches.append((start, end, data))
 | 
			
		||||
 | 
			
		||||
        if not matches:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        # Return the smallest interval, sort by memory location when
 | 
			
		||||
        # there are multiple matches with the same interval size. This
 | 
			
		||||
        # mirrors the ordering in the intervan tree.
 | 
			
		||||
        best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
 | 
			
		||||
        return best_match[2]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_balanced(tree: IntervalTree[T]) -> bool:
 | 
			
		||||
    """Check if the AVL tree is properly balanced."""
 | 
			
		||||
 | 
			
		||||
    def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
 | 
			
		||||
        if node is None:
 | 
			
		||||
            return True, 0
 | 
			
		||||
 | 
			
		||||
        # Left and right subtrees should be balanced.
 | 
			
		||||
        left_balanced, left_height = check_balance(node.left)
 | 
			
		||||
        if not left_balanced:
 | 
			
		||||
            return False, -1
 | 
			
		||||
 | 
			
		||||
        right_balanced, right_height = check_balance(node.right)
 | 
			
		||||
        if not right_balanced:
 | 
			
		||||
            return False, -1
 | 
			
		||||
 | 
			
		||||
        # The difference in height should not exceed 1.
 | 
			
		||||
        if abs(left_height - right_height) > 1:
 | 
			
		||||
            return False, -1
 | 
			
		||||
 | 
			
		||||
        # Check if the height is correct.
 | 
			
		||||
        expected_height = 1 + max(left_height, right_height)
 | 
			
		||||
        if node.height != expected_height:
 | 
			
		||||
            return False, -1
 | 
			
		||||
 | 
			
		||||
        return True, expected_height
 | 
			
		||||
 | 
			
		||||
    balanced, _ = check_balance(tree.root)
 | 
			
		||||
    return balanced
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def populated_tree() -> IntervalTree[str]:
 | 
			
		||||
    """Provides a pre-populated IntervalTree for testing."""
 | 
			
		||||
    tree = IntervalTree[str]()
 | 
			
		||||
    kernels = [
 | 
			
		||||
        (80, 89, "Kernel_A_General_80_89"),
 | 
			
		||||
        (86, 89, "Kernel_B_Ampere_86_89"),
 | 
			
		||||
        (80, 86, "Kernel_C_Older_Ampere_80_86"),
 | 
			
		||||
        (70, 75, "Kernel_D_Volta_70_75"),
 | 
			
		||||
        (86, 87, "Kernel_E_Specific_86_87"),
 | 
			
		||||
    ]
 | 
			
		||||
    for start, end, name in kernels:
 | 
			
		||||
        tree.insert(start, end, name)
 | 
			
		||||
    return tree
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
 | 
			
		||||
    # Check that the smallest inteval is selected when there are
 | 
			
		||||
    # multiple matching intervals.
 | 
			
		||||
    assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_find_single_match(populated_tree):
 | 
			
		||||
    assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
 | 
			
		||||
    assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_no_match_outside_all_ranges(populated_tree):
 | 
			
		||||
    # Check that no interval is found when the value is out of range
 | 
			
		||||
    # (too small/too large).
 | 
			
		||||
    assert populated_tree.find_smallest_interval(65) is None
 | 
			
		||||
    assert populated_tree.find_smallest_interval(95) is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_no_match_in_gap_between_ranges(populated_tree):
 | 
			
		||||
    # Check that no interval is found when the value is between two
 | 
			
		||||
    # intervals.
 | 
			
		||||
    assert populated_tree.find_smallest_interval(78) is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_boundary_conditions_start_and_end(populated_tree):
 | 
			
		||||
    # Test exact upper/lower bounds of intervals.
 | 
			
		||||
    assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
 | 
			
		||||
    assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_empty_tree():
 | 
			
		||||
    # Searching in an empty tree should return None.
 | 
			
		||||
    empty_tree = IntervalTree[str]()
 | 
			
		||||
    assert empty_tree.find_smallest_interval(100) is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_multiple_equally_specific_matches():
 | 
			
		||||
    # Check that we pick the match in a stable way when there is are
 | 
			
		||||
    # multiple matching intervals with the same size.
 | 
			
		||||
    tree = IntervalTree[str]()
 | 
			
		||||
    str1 = "First_Narrow_Kernel"
 | 
			
		||||
    str2 = "Second_Narrow_Kernel"
 | 
			
		||||
    tree.insert(10, 20, "Wide_Kernel")
 | 
			
		||||
    tree.insert(12, 17, str1)
 | 
			
		||||
    tree.insert(14, 19, str2)
 | 
			
		||||
 | 
			
		||||
    if id(str1) < id(str2):
 | 
			
		||||
        assert tree.find_smallest_interval(15) == str1
 | 
			
		||||
    else:
 | 
			
		||||
        assert tree.find_smallest_interval(15) == str2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_property_based_interval_tree():
 | 
			
		||||
    # Quick-check property-based testing:
 | 
			
		||||
    #
 | 
			
		||||
    # - Verify that the tree is balanced after each insertion.
 | 
			
		||||
    # - Verify the query against a simple list-based implementation.
 | 
			
		||||
 | 
			
		||||
    random.seed(42)  # For reproducible tests
 | 
			
		||||
 | 
			
		||||
    test_points = list(range(0, 101))
 | 
			
		||||
 | 
			
		||||
    for _ in range(5):
 | 
			
		||||
        tree = IntervalTree[str]()
 | 
			
		||||
        simple = SimpleIntervalStore[str]()
 | 
			
		||||
 | 
			
		||||
        intervals = []
 | 
			
		||||
        for i in range(100):
 | 
			
		||||
            start = random.randint(0, 90)
 | 
			
		||||
            end = random.randint(start, 100)
 | 
			
		||||
            data = f"interval_{i}_s{start}_e{end}"
 | 
			
		||||
            intervals.append((start, end, data))
 | 
			
		||||
 | 
			
		||||
        for i, (start, end, data) in enumerate(intervals):
 | 
			
		||||
            tree.insert(start, end, data)
 | 
			
		||||
            simple.insert(start, end, data)
 | 
			
		||||
 | 
			
		||||
            # Check that tree is still balanced
 | 
			
		||||
            assert is_balanced(
 | 
			
		||||
                tree
 | 
			
		||||
            ), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
 | 
			
		||||
 | 
			
		||||
            for point in test_points:
 | 
			
		||||
                tree_result = tree.find_smallest_interval(point)
 | 
			
		||||
                simple_result = simple.find_smallest_interval(point)
 | 
			
		||||
 | 
			
		||||
                assert tree_result == simple_result, (
 | 
			
		||||
                    f"Mismatch for point {point} after inserting {i+1} intervals. "
 | 
			
		||||
                    f"Tree: {tree_result}, Simple: {simple_result}. "
 | 
			
		||||
                    f"Last inserted: ({start}, {end})"
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_property_based_edge_cases():
 | 
			
		||||
    random.seed(123)
 | 
			
		||||
 | 
			
		||||
    tree = IntervalTree[str]()
 | 
			
		||||
    simple = SimpleIntervalStore[str]()
 | 
			
		||||
 | 
			
		||||
    # Single-point intervals.
 | 
			
		||||
    for i in range(10):
 | 
			
		||||
        point = random.randint(0, 100)
 | 
			
		||||
        data = f"single_point_{i}_{point}"
 | 
			
		||||
        tree.insert(point, point, data)
 | 
			
		||||
        simple.insert(point, point, data)
 | 
			
		||||
 | 
			
		||||
        assert is_balanced(
 | 
			
		||||
            tree
 | 
			
		||||
        ), f"Tree unbalanced after inserting single point {point}"
 | 
			
		||||
 | 
			
		||||
        # Test the exact point and neighbors
 | 
			
		||||
        for test_point in [point - 1, point, point + 1]:
 | 
			
		||||
            if 0 <= test_point <= 100:
 | 
			
		||||
                tree_result = tree.find_smallest_interval(test_point)
 | 
			
		||||
                simple_result = simple.find_smallest_interval(test_point)
 | 
			
		||||
                assert tree_result == simple_result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unique_intervals_override():
 | 
			
		||||
    """Test that inserting an interval with the same start/end overrides the previous value."""
 | 
			
		||||
    tree = IntervalTree[str]()
 | 
			
		||||
 | 
			
		||||
    tree.insert(10, 20, "original_value")
 | 
			
		||||
    assert tree.find_smallest_interval(15) == "original_value"
 | 
			
		||||
 | 
			
		||||
    tree.insert(10, 20, "new_value")
 | 
			
		||||
    assert tree.find_smallest_interval(15) == "new_value"
 | 
			
		||||
 | 
			
		||||
    tree.insert(10, 25, "different_interval")
 | 
			
		||||
    results = tree.search(15)
 | 
			
		||||
    assert "new_value" in results
 | 
			
		||||
    assert "different_interval" in results
 | 
			
		||||
    assert len(results) == 2
 | 
			
		||||
 | 
			
		||||
    tree.insert(10, 20, "final_value")
 | 
			
		||||
    assert tree.find_smallest_interval(15) == "final_value"
 | 
			
		||||
 | 
			
		||||
    assert is_balanced(tree)
 | 
			
		||||
@ -1,8 +1,18 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
 | 
			
		||||
from kernels import load_kernel
 | 
			
		||||
from kernels.cli import download_kernels
 | 
			
		||||
from kernels.layer import (
 | 
			
		||||
    LockedLayerRepository,
 | 
			
		||||
    Mode,
 | 
			
		||||
    kernelize,
 | 
			
		||||
    use_kernel_forward_from_hub,
 | 
			
		||||
    use_kernel_mapping,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Mock download arguments class.
 | 
			
		||||
@ -17,8 +27,34 @@ def test_download_all_hash_validation():
 | 
			
		||||
    download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.cuda_only
 | 
			
		||||
def test_load_locked():
 | 
			
		||||
    project_dir = Path(__file__).parent / "kernel_locking"
 | 
			
		||||
    # Also validates that hashing works correctly.
 | 
			
		||||
    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
 | 
			
		||||
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_layer_locked():
 | 
			
		||||
    project_dir = Path(__file__).parent / "layer_locking"
 | 
			
		||||
 | 
			
		||||
    @use_kernel_forward_from_hub("Version")
 | 
			
		||||
    class Version(nn.Module):
 | 
			
		||||
        def forward(self) -> str:
 | 
			
		||||
            return "0.0.0"
 | 
			
		||||
 | 
			
		||||
    version = Version()
 | 
			
		||||
 | 
			
		||||
    with use_kernel_mapping(
 | 
			
		||||
        {
 | 
			
		||||
            "Version": {
 | 
			
		||||
                "cuda": LockedLayerRepository(
 | 
			
		||||
                    repo_id="kernels-test/versions",
 | 
			
		||||
                    layer_name="Version",
 | 
			
		||||
                    lockfile=project_dir / "kernels.lock",
 | 
			
		||||
                )
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    ):
 | 
			
		||||
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
 | 
			
		||||
        assert version() == "0.1.1"
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										115
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,115 @@
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
import tempfile
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from huggingface_hub import delete_repo, model_info
 | 
			
		||||
 | 
			
		||||
from kernels.cli import upload_kernels
 | 
			
		||||
 | 
			
		||||
REPO_ID = "valid_org/kernels-upload-test"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
PY_CONTENT = """\
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    print("Hello from torch-universal!")
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class UploadArgs:
 | 
			
		||||
    kernel_dir: None
 | 
			
		||||
    repo_id: None
 | 
			
		||||
    private: False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def next_filename(path: Path) -> Path:
 | 
			
		||||
    """
 | 
			
		||||
    Given a path like foo_2050.py, return foo_2051.py.
 | 
			
		||||
    """
 | 
			
		||||
    m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
 | 
			
		||||
    if not m:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    prefix, number, suffix = m.groups()
 | 
			
		||||
    new_number = str(int(number) + 1).zfill(len(number))
 | 
			
		||||
    return path.with_name(f"{prefix}{new_number}{suffix}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_filename_to_change(repo_filenames):
 | 
			
		||||
    for f in repo_filenames:
 | 
			
		||||
        if "foo" in f and f.endswith(".py"):
 | 
			
		||||
            filename_to_change = os.path.basename(f)
 | 
			
		||||
            break
 | 
			
		||||
    assert filename_to_change
 | 
			
		||||
    return filename_to_change
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
 | 
			
		||||
    try:
 | 
			
		||||
        repo_info = model_info(repo_id=repo_id, files_metadata=True)
 | 
			
		||||
        repo_siblings = repo_info.siblings
 | 
			
		||||
        if repo_siblings is not None:
 | 
			
		||||
            return [f.rfilename for f in repo_siblings]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("No repo siblings found.")
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        logging.error(f"Error connecting to the Hub: {e}.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.token
 | 
			
		||||
@pytest.mark.is_staging_test
 | 
			
		||||
def test_kernel_upload_works_as_expected():
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        path = f"{tmpdir}/build/torch-universal/upload_test"
 | 
			
		||||
        build_dir = Path(path)
 | 
			
		||||
        build_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        script_path = build_dir / "foo.py"
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    assert any(str(script_path.name) for f in repo_filenames)
 | 
			
		||||
    delete_repo(repo_id=REPO_ID)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.token
 | 
			
		||||
@pytest.mark.is_staging_test
 | 
			
		||||
def test_kernel_upload_deletes_as_expected():
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        path = f"{tmpdir}/build/torch-universal/upload_test"
 | 
			
		||||
        build_dir = Path(path)
 | 
			
		||||
        build_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        script_path = build_dir / "foo_2025.py"
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    filename_to_change = get_filename_to_change(repo_filenames)
 | 
			
		||||
 | 
			
		||||
    with tempfile.TemporaryDirectory() as tmpdir:
 | 
			
		||||
        path = f"{tmpdir}/build/torch-universal/upload_test"
 | 
			
		||||
        build_dir = Path(path)
 | 
			
		||||
        build_dir.mkdir(parents=True, exist_ok=True)
 | 
			
		||||
        changed_filename = next_filename(Path(filename_to_change))
 | 
			
		||||
        script_path = build_dir / changed_filename
 | 
			
		||||
        script_path.write_text(PY_CONTENT)
 | 
			
		||||
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
 | 
			
		||||
 | 
			
		||||
    repo_filenames = get_filenames_from_a_repo(REPO_ID)
 | 
			
		||||
    assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
 | 
			
		||||
    assert not any(
 | 
			
		||||
        str(filename_to_change) in k for k in repo_filenames
 | 
			
		||||
    ), f"{repo_filenames=}"
 | 
			
		||||
    delete_repo(repo_id=REPO_ID)
 | 
			
		||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user