mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-10-31 19:54:28 +08:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			release-0.
			...
			api-docs
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ae43772a67 | |||
| c02f88cd2a | |||
| a3db6f437c | 
							
								
								
									
										19
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								.github/workflows/build_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,19 @@ | |||||||
|  | name: Build documentation | ||||||
|  |  | ||||||
|  | on: | ||||||
|  |   push: | ||||||
|  |     paths: | ||||||
|  |       - "docs/source/**" | ||||||
|  |     branches: | ||||||
|  |       - main | ||||||
|  |       - doc-builder* | ||||||
|  |       - v*-release | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   build: | ||||||
|  |     uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main | ||||||
|  |     with: | ||||||
|  |       commit_sha: ${{ github.sha }} | ||||||
|  |       package: kernels | ||||||
|  |     secrets: | ||||||
|  |       hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} | ||||||
							
								
								
									
										18
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								.github/workflows/build_pr_documentation.yaml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,18 @@ | |||||||
|  | name: Build PR Documentation | ||||||
|  |  | ||||||
|  | on: | ||||||
|  |   pull_request: | ||||||
|  |     paths: | ||||||
|  |       - "docs/source/**" | ||||||
|  |  | ||||||
|  | concurrency: | ||||||
|  |   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} | ||||||
|  |   cancel-in-progress: true | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   build: | ||||||
|  |     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main | ||||||
|  |     with: | ||||||
|  |       commit_sha: ${{ github.event.pull_request.head.sha }} | ||||||
|  |       pr_number: ${{ github.event.number }} | ||||||
|  |       package: kernels | ||||||
							
								
								
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										120
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,120 +0,0 @@ | |||||||
| name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI |  | ||||||
|  |  | ||||||
| on: push |  | ||||||
|  |  | ||||||
| jobs: |  | ||||||
|   build: |  | ||||||
|     name: Build distribution 📦 |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|  |  | ||||||
|     steps: |  | ||||||
|       - uses: actions/checkout@v4 |  | ||||||
|         with: |  | ||||||
|           persist-credentials: false |  | ||||||
|       - name: Set up Python |  | ||||||
|         uses: actions/setup-python@v5 |  | ||||||
|         with: |  | ||||||
|           python-version: "3.9" |  | ||||||
|       - name: Install pypa/build |  | ||||||
|         run: >- |  | ||||||
|           python3 -m |  | ||||||
|           pip install |  | ||||||
|           build |  | ||||||
|           --user |  | ||||||
|       - name: Build a binary wheel and a source tarball |  | ||||||
|         run: python3 -m build |  | ||||||
|       - name: Store the distribution packages |  | ||||||
|         uses: actions/upload-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           name: python-package-distributions |  | ||||||
|           path: dist/ |  | ||||||
|  |  | ||||||
|   publish-to-pypi: |  | ||||||
|     name: >- |  | ||||||
|       Publish Python 🐍 distribution 📦 to PyPI |  | ||||||
|     if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes |  | ||||||
|     needs: |  | ||||||
|       - build |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|     environment: |  | ||||||
|       name: pypi |  | ||||||
|       url: https://pypi.org/p/kernels |  | ||||||
|     permissions: |  | ||||||
|       id-token: write # IMPORTANT: mandatory for trusted publishing |  | ||||||
|  |  | ||||||
|     steps: |  | ||||||
|       - name: Download all the dists |  | ||||||
|         uses: actions/download-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           name: python-package-distributions |  | ||||||
|           path: dist/ |  | ||||||
|       - name: Publish distribution 📦 to PyPI |  | ||||||
|         uses: pypa/gh-action-pypi-publish@release/v1 |  | ||||||
|  |  | ||||||
|   github-release: |  | ||||||
|     name: >- |  | ||||||
|       Sign the Python 🐍 distribution 📦 with Sigstore |  | ||||||
|       and upload them to GitHub Release |  | ||||||
|     needs: |  | ||||||
|       - publish-to-pypi |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|  |  | ||||||
|     permissions: |  | ||||||
|       contents: write # IMPORTANT: mandatory for making GitHub Releases |  | ||||||
|       id-token: write # IMPORTANT: mandatory for sigstore |  | ||||||
|  |  | ||||||
|     steps: |  | ||||||
|       - name: Download all the dists |  | ||||||
|         uses: actions/download-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           name: python-package-distributions |  | ||||||
|           path: dist/ |  | ||||||
|       - name: Sign the dists with Sigstore |  | ||||||
|         uses: sigstore/gh-action-sigstore-python@v3.0.0 |  | ||||||
|         with: |  | ||||||
|           inputs: >- |  | ||||||
|             ./dist/*.tar.gz |  | ||||||
|             ./dist/*.whl |  | ||||||
|       - name: Create GitHub Release |  | ||||||
|         env: |  | ||||||
|           GITHUB_TOKEN: ${{ github.token }} |  | ||||||
|         run: >- |  | ||||||
|           gh release create |  | ||||||
|           "$GITHUB_REF_NAME" |  | ||||||
|           --repo "$GITHUB_REPOSITORY" |  | ||||||
|           --notes "" |  | ||||||
|       - name: Upload artifact signatures to GitHub Release |  | ||||||
|         env: |  | ||||||
|           GITHUB_TOKEN: ${{ github.token }} |  | ||||||
|         # Upload to GitHub Release using the `gh` CLI. |  | ||||||
|         # `dist/` contains the built packages, and the |  | ||||||
|         # sigstore-produced signatures and certificates. |  | ||||||
|         run: >- |  | ||||||
|           gh release upload |  | ||||||
|           "$GITHUB_REF_NAME" dist/** |  | ||||||
|           --repo "$GITHUB_REPOSITORY" |  | ||||||
|  |  | ||||||
|   publish-to-testpypi: |  | ||||||
|     name: Publish Python 🐍 distribution 📦 to TestPyPI |  | ||||||
|     needs: |  | ||||||
|       - build |  | ||||||
|     runs-on: ubuntu-latest |  | ||||||
|  |  | ||||||
|     environment: |  | ||||||
|       name: testpypi |  | ||||||
|       url: https://test.pypi.org/p/kernels |  | ||||||
|  |  | ||||||
|     permissions: |  | ||||||
|       id-token: write # IMPORTANT: mandatory for trusted publishing |  | ||||||
|  |  | ||||||
|     steps: |  | ||||||
|       - name: Download all the dists |  | ||||||
|         uses: actions/download-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           name: python-package-distributions |  | ||||||
|           path: dist/ |  | ||||||
|       - name: Publish distribution 📦 to TestPyPI |  | ||||||
|         uses: pypa/gh-action-pypi-publish@release/v1 |  | ||||||
|         with: |  | ||||||
|           repository-url: https://test.pypi.org/legacy/ |  | ||||||
|           skip-existing: true # Only upload when the version is unique. |  | ||||||
							
								
								
									
										14
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										14
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -24,7 +24,7 @@ jobs: | |||||||
|       max-parallel: 4 |       max-parallel: 4 | ||||||
|       matrix: |       matrix: | ||||||
|         python-version: ["3.10", "3.12"] |         python-version: ["3.10", "3.12"] | ||||||
|         torch-version: ["2.6.0", "2.7.0"] |         torch-version: ["2.5.1", "2.6.0"] | ||||||
|  |  | ||||||
|     env: |     env: | ||||||
|       UV_PYTHON_PREFERENCE: only-managed |       UV_PYTHON_PREFERENCE: only-managed | ||||||
| @ -53,18 +53,6 @@ jobs: | |||||||
|       - name: Run tests |       - name: Run tests | ||||||
|         run: uv run pytest tests |         run: uv run pytest tests | ||||||
|  |  | ||||||
|       - name: Check kernel conversion |  | ||||||
|         run: | |  | ||||||
|           uv pip install wheel |  | ||||||
|           uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1 |  | ||||||
|           uv pip install triton_layer_norm-0.0.1*.whl |  | ||||||
|           uv run python -c "import triton_layer_norm" |  | ||||||
|  |  | ||||||
|       - name: Check README generation |  | ||||||
|         # For now, just checks that generation doesn't fail. |  | ||||||
|         run: | |  | ||||||
|           uv run kernels generate-readme kernels-community/triton-layer-norm |  | ||||||
|  |  | ||||||
|       - name: Import check without torch |       - name: Import check without torch | ||||||
|         run: | |         run: | | ||||||
|           uv pip uninstall torch |           uv pip uninstall torch | ||||||
|  | |||||||
| @ -61,5 +61,4 @@ the Hub. | |||||||
| - [Environment variables](docs/env.md) | - [Environment variables](docs/env.md) | ||||||
| - [Using kernels in a Docker container](docs/docker.md) | - [Using kernels in a Docker container](docs/docker.md) | ||||||
| - [Kernel requirements](docs/kernel-requirements.md) | - [Kernel requirements](docs/kernel-requirements.md) | ||||||
| - [Frequently Asked Questions](docs/faq.md) |  | ||||||
| - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||||
|  | |||||||
							
								
								
									
										13
									
								
								docs/faq.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								docs/faq.md
									
									
									
									
									
								
							| @ -1,13 +0,0 @@ | |||||||
| # FAQ |  | ||||||
|  |  | ||||||
| ## Why is the kernelization step needed? |  | ||||||
|  |  | ||||||
| In earlier versions of `kernels`, a layer's `forward` was replaced by |  | ||||||
| `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The |  | ||||||
| new `forward` would dispatch to a kernel based on the device type, |  | ||||||
| whether a model was training, etc. However, this approach was |  | ||||||
| fundamentally incompatible with `torch.compile` since it relied |  | ||||||
| on data-dependent branching. |  | ||||||
|  |  | ||||||
| To avoid branching, we have to make dispatch decisions ahead of time, |  | ||||||
| which is what the `kernelize` function does. |  | ||||||
							
								
								
									
										197
									
								
								docs/layers.md
									
									
									
									
									
								
							
							
						
						
									
										197
									
								
								docs/layers.md
									
									
									
									
									
								
							| @ -1,197 +0,0 @@ | |||||||
| # Layers |  | ||||||
|  |  | ||||||
| A kernel can provide layers in addition to kernel functions. A layer from |  | ||||||
| the Hub can replace the `forward` method of an existing layer for a certain |  | ||||||
| device type. This makes it possible to provide more performant kernels for |  | ||||||
| existing layers. |  | ||||||
|  |  | ||||||
| See [Kernel requirements](kernel-requirements.md) for more information the |  | ||||||
| requirements of Hub layers. |  | ||||||
|  |  | ||||||
| ## Making a layer extensible with kernels from the hub |  | ||||||
|  |  | ||||||
| ### Using a decorator |  | ||||||
|  |  | ||||||
| A layer can be made extensible with the `use_kernel_forward_from_hub` |  | ||||||
| decorator. For example: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| @use_kernel_forward_from_hub("SiluAndMul") |  | ||||||
| class SiluAndMul(nn.Module): |  | ||||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: |  | ||||||
|         d = input.shape[-1] // 2 |  | ||||||
|         return F.silu(input[..., :d]) * input[..., d:] |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| The decorator does not change the behavior of the class -- it annotates |  | ||||||
| the class with the given name (here `SiluAndMul`). The `kernelize` function |  | ||||||
| described below uses this name to look up kernels for the layer. |  | ||||||
|  |  | ||||||
| ### External layers |  | ||||||
|  |  | ||||||
| An existing layer that does not (yet) have the `use_kernel_forward_from_hub` |  | ||||||
| decorator can be made extensible using the `replace_kernel_forward_from_hub` |  | ||||||
| function: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| from somelibrary import SiluAndMul |  | ||||||
|  |  | ||||||
| replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| **Warning:** we strongly recommend using layers with a decorator, since |  | ||||||
| it signifies that the maintainer intends to keep the `forward` signature |  | ||||||
| compatible with layers from the hub. |  | ||||||
|  |  | ||||||
| ## Kernelizing a model |  | ||||||
|  |  | ||||||
| A model will not use Hub kernels by default, even if it contains extensible |  | ||||||
| layers. To enable the use of Hub kernels in the model, it needs to be |  | ||||||
| 'kernelized' using the `kernelize` function. This function traverses the |  | ||||||
| model graph and replaces the `forward` methods of extensible layers for which |  | ||||||
| Hub kernels are registered. Kernelize can be used as follows: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| model = MyModel(...) |  | ||||||
| model = kernelize(model, mode=Mode.INFERENCE) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| The `mode` specifies that the model will be used in inference. Similarly, |  | ||||||
| you can ask `kernelize` to prepare the model for training: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| model = MyModel(...) |  | ||||||
| model = kernelize(model, mode=Mode.TRAINING) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| **Note:** the `kernelize` function modifies the model in-place, the model |  | ||||||
| itself is returned as a convenience. |  | ||||||
|  |  | ||||||
| ### Kernel device |  | ||||||
|  |  | ||||||
| Kernels can be registered per device type. For instance, separate `cuda` and |  | ||||||
| `metal` kernels could be registered for the name `SiluAndMul`. By default, |  | ||||||
| `kernelize` will try to infer the device type from the model's parameters. |  | ||||||
| You can pass the device type to `kernelize` if the device type cannot be |  | ||||||
| inferred (e.g. because the model has no parameters): |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| model = MyModel(...) |  | ||||||
| model = kernelize(model, device="cuda", mode=Mode.INFERENCE) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| ### `torch.compile` |  | ||||||
|  |  | ||||||
| Not all Hub kernels support `torch.compile`. If you want to compile a model |  | ||||||
| after kernelizing it, you need to add this to the mode. You can use the |  | ||||||
| set union (`|`) operator to add `TORCH_COMPILE` to the mode: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| model = MyModel(...) |  | ||||||
| model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| ### Fallback `forward` |  | ||||||
|  |  | ||||||
| If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered |  | ||||||
| kernel does not support backward passes or `torch.compile` respectively, |  | ||||||
| `kernenize` will fall back to the original, non-kernelized, layer. You |  | ||||||
| can let `kernelize` raise an exception instead by using `use_fallback=False`: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| model = MyModel(...) |  | ||||||
| model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| This can be useful if you want to guarantee that Hub kernels are used. |  | ||||||
|  |  | ||||||
| ## Registering a hub kernel for a layer |  | ||||||
|  |  | ||||||
| `kernelize` relies on kernel mappings to find Hub kernels for layers. |  | ||||||
| Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on |  | ||||||
| the Hub. For example: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| kernel_layer_mapping = { |  | ||||||
|     "SiluAndMul": { |  | ||||||
|         "cuda": LayerRepository( |  | ||||||
|             repo_id="kernels-community/activation", |  | ||||||
|             layer_name="SiluAndMul", |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| You can register such a mapping using `register_kernel_mapping`: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| register_kernel_mapping(kernel_layer_mapping) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| This will register the kernel mapping in the current context, which is |  | ||||||
| normally global. It is recommended to scope the mapping to where it is |  | ||||||
| used with the `use_kernel_mapping` context manager: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| with use_kernel_mapping(kernel_layer_mapping): |  | ||||||
|     # Use the layer for which the mapping is applied. |  | ||||||
|     model = kernelize(model) |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| This ensures that the mapping is not active anymore outside the |  | ||||||
| `with`-scope. |  | ||||||
|  |  | ||||||
| ### Registering kernels for specific modes |  | ||||||
|  |  | ||||||
| You might want to register two different kernels for a particular layer, |  | ||||||
| where one kernel is optimized for a specific mode. You can do so by |  | ||||||
| registering layer repositories for specific modes. For example: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| kernel_layer_mapping = { |  | ||||||
|     "SiluAndMul": { |  | ||||||
|         "cuda": { |  | ||||||
|           Mode.INFERENCE: LayerRepository( |  | ||||||
|               repo_id="kernels-community/activation-inference-optimized", |  | ||||||
|               layer_name="SiluAndMul", |  | ||||||
|           ), |  | ||||||
|           Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository( |  | ||||||
|               repo_id="kernels-community/activation-training-optimized", |  | ||||||
|               layer_name="SiluAndMul", |  | ||||||
|           ), |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| The kernels will match exactly on the mode. So, for instance in the example above, no kernel |  | ||||||
| layer is used when the `mode` passed to `kernelize` is |  | ||||||
| `Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to |  | ||||||
| register a kernel to be used when the mode does not match any of the |  | ||||||
| modes in the mapping, you can use the special `Mode.DEFAULT` mode to do |  | ||||||
| so. For example: |  | ||||||
|  |  | ||||||
| ```python |  | ||||||
| kernel_layer_mapping = { |  | ||||||
|     "SiluAndMul": { |  | ||||||
|         "cuda": { |  | ||||||
|           Mode.DEFAULT: LayerRepository( |  | ||||||
|               repo_id="kernels-community/activation", |  | ||||||
|               layer_name="SiluAndMul", |  | ||||||
|           ), |  | ||||||
|           Mode.INFERENCE: LayerRepository( |  | ||||||
|               repo_id="kernels-community/activation-inference-optimized", |  | ||||||
|               layer_name="SiluAndMul", |  | ||||||
|           ), |  | ||||||
|           Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository( |  | ||||||
|               repo_id="kernels-community/activation-training-optimized", |  | ||||||
|               layer_name="SiluAndMul", |  | ||||||
|           ), |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| In this case, modes other than `Mode.INFERENCE` and |  | ||||||
| `Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using |  | ||||||
| `kernels-community/activation`. |  | ||||||
							
								
								
									
										26
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								docs/source/_toctree.yml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | |||||||
|  | - sections: | ||||||
|  |     - local: index | ||||||
|  |       title: Introduction | ||||||
|  |     - local: installation | ||||||
|  |       title: Installation | ||||||
|  |   title: Getting started | ||||||
|  | - sections: | ||||||
|  |     - local: basic_usage | ||||||
|  |       title: Basic Usage | ||||||
|  |     - local: layers | ||||||
|  |       title: Using Layers | ||||||
|  |     - local: locking | ||||||
|  |       title: Locking Kernel Versions | ||||||
|  |     - local: env | ||||||
|  |       title: Environment Variables | ||||||
|  |   title: Usage Guide | ||||||
|  | - sections: | ||||||
|  |     - local: api/kernels | ||||||
|  |       title: Kernels | ||||||
|  |     - local: api/layers | ||||||
|  |       title: Layers | ||||||
|  |   title: API Reference | ||||||
|  | - sections: | ||||||
|  |     - local: kernel_requirements | ||||||
|  |       title: Kernel Requirements | ||||||
|  |   title: Developer Guide | ||||||
							
								
								
									
										21
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								docs/source/api/kernels.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,21 @@ | |||||||
|  | # Kernels API Reference | ||||||
|  |  | ||||||
|  | ## Main Functions | ||||||
|  |  | ||||||
|  | ### get_kernel | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.get_kernel | ||||||
|  |  | ||||||
|  | ### has_kernel | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.has_kernel | ||||||
|  |  | ||||||
|  | ## Loading locked kernels | ||||||
|  |  | ||||||
|  | ### load_kernel | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.load_kernel | ||||||
|  |  | ||||||
|  | ### get_locked_kernel | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.get_locked_kernel | ||||||
							
								
								
									
										31
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								docs/source/api/layers.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,31 @@ | |||||||
|  | # Layers API Reference | ||||||
|  |  | ||||||
|  | ## Making layers kernel-aware | ||||||
|  |  | ||||||
|  | ### use_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.use_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  | ### replace_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.replace_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  | ## Registering kernel mappings | ||||||
|  |  | ||||||
|  | ### use_kernel_mapping | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.use_kernel_mapping | ||||||
|  |  | ||||||
|  | ### register_kernel_mapping | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.register_kernel_mapping | ||||||
|  |  | ||||||
|  | ## Classes | ||||||
|  |  | ||||||
|  | ### LayerRepository | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.LayerRepository | ||||||
|  |  | ||||||
|  | ### Device | ||||||
|  |  | ||||||
|  | [[autodoc]] kernels.Device | ||||||
							
								
								
									
										34
									
								
								docs/source/basic_usage.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								docs/source/basic_usage.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,34 @@ | |||||||
|  | # Basic Usage | ||||||
|  |  | ||||||
|  | ## Loading Kernels | ||||||
|  |  | ||||||
|  | Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | import torch | ||||||
|  | from kernels import get_kernel | ||||||
|  |  | ||||||
|  | # Download optimized kernels from the Hugging Face hub | ||||||
|  | activation = get_kernel("kernels-community/activation") | ||||||
|  |  | ||||||
|  | # Create a random tensor | ||||||
|  | x = torch.randn((10, 10), dtype=torch.float16, device="cuda") | ||||||
|  |  | ||||||
|  | # Run the kernel | ||||||
|  | y = torch.empty_like(x) | ||||||
|  | activation.gelu_fast(y, x) | ||||||
|  |  | ||||||
|  | print(y) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Checking Kernel Availability | ||||||
|  |  | ||||||
|  | You can check if a specific kernel is available for your environment: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from kernels import has_kernel | ||||||
|  |  | ||||||
|  | # Check if kernel is available for current environment | ||||||
|  | is_available = has_kernel("kernels-community/activation") | ||||||
|  | print(f"Kernel available: {is_available}") | ||||||
|  | ``` | ||||||
							
								
								
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								docs/source/index.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,20 @@ | |||||||
|  | # Kernels | ||||||
|  |  | ||||||
|  | <div align="center"> | ||||||
|  | <img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo"> | ||||||
|  | </div> | ||||||
|  |  | ||||||
|  | The Kernel Hub allows Python libraries and applications to load compute | ||||||
|  | kernels directly from the [Hub](https://hf.co/). To support this kind | ||||||
|  | of dynamic loading, Hub kernels differ from traditional Python kernel | ||||||
|  | packages in that they are made to be: | ||||||
|  |  | ||||||
|  | - **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`. | ||||||
|  | - **Unique**: multiple versions of the same kernel can be loaded in the | ||||||
|  |   same Python process. | ||||||
|  | - **Compatible**: kernels must support all recent versions of Python and | ||||||
|  |   the different PyTorch build configurations (various CUDA versions | ||||||
|  |   and C++ ABIs). Furthermore, older C library versions must be supported. | ||||||
|  |  | ||||||
|  | You can [search for kernels](https://huggingface.co/models?other=kernel) on | ||||||
|  | the Hub. | ||||||
							
								
								
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								docs/source/installation.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,16 @@ | |||||||
|  | # Installation | ||||||
|  |  | ||||||
|  | Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA): | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | pip install kernels | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | # Using kernels in a Docker container | ||||||
|  |  | ||||||
|  | build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||||
|  | docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||||
|  | ``` | ||||||
| @ -1,11 +1,8 @@ | |||||||
| # Kernel requirements | # Kernel requirements | ||||||
| 
 | 
 | ||||||
| Kernels on the Hub must fulfill the requirements outlined on this page. By | Kernels on the Hub must fulfill the requirements outlined on this page. | ||||||
| ensuring kernels are compliant, they can be used on a wide range of Linux |  | ||||||
| systems and Torch builds. |  | ||||||
| 
 |  | ||||||
| You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||||
| to build compliant kernels. | to build conforming kernels. | ||||||
| 
 | 
 | ||||||
| ## Directory layout | ## Directory layout | ||||||
| 
 | 
 | ||||||
| @ -13,21 +10,34 @@ A kernel repository on the Hub must contain a `build` directory. This | |||||||
| directory contains build variants of a kernel in the form of directories | directory contains build variants of a kernel in the form of directories | ||||||
| following the template | following the template | ||||||
| `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`. | `<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`. | ||||||
| For example `build/torch26-cxx98-cu118-x86_64-linux`. | For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently | ||||||
|  | recommended build variants are: | ||||||
| 
 | 
 | ||||||
| Each variant directory must contain a single directory with the same name | - `torch25-cxx11-cu118-x86_64-linux` | ||||||
|  | - `torch25-cxx11-cu121-x86_64-linux` | ||||||
|  | - `torch25-cxx11-cu124-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu118-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu121-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu118-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu126-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu118-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu126-x86_64-linux` | ||||||
|  | 
 | ||||||
|  | This list will be updated as new PyTorch versions are released. Kernels | ||||||
|  | that are in pure Python (e.g. Triton kernels) only need to provide a | ||||||
|  | single build variant: | ||||||
|  | 
 | ||||||
|  | - `torch-universal` | ||||||
|  | 
 | ||||||
|  | Each variant directory should contain a single directory with the same name | ||||||
| as the repository (replacing `-` by `_`). For instance, kernels in the | as the repository (replacing `-` by `_`). For instance, kernels in the | ||||||
| `kernels-community/activation` repository have a directories like | `kernels-community/activation` repository have a directories like | ||||||
| `build/<variant>/activation`. This directory | `build/<variant>/activation`. This directory | ||||||
| must be a Python package with an `__init__.py` file. | must be a Python package with an `__init__.py` file. | ||||||
| 
 | 
 | ||||||
| ## Build variants |  | ||||||
| 
 |  | ||||||
| A kernel can be compliant for a specific compute framework (e.g. CUDA) or |  | ||||||
| architecture (e.g. x86_64). For compliance with a compute framework and |  | ||||||
| architecture combination, all the variants from the [build variant list](https://github.com/huggingface/kernel-builder/blob/main/docs/build-variants.md) |  | ||||||
| must be available for that combination. |  | ||||||
| 
 |  | ||||||
| ## Versioning | ## Versioning | ||||||
| 
 | 
 | ||||||
| Kernels are versioned on the Hub using Git tags. Version tags must be of | Kernels are versioned on the Hub using Git tags. Version tags must be of | ||||||
| @ -37,14 +47,8 @@ to resolve the version constraints. | |||||||
| ## Native Python module | ## Native Python module | ||||||
| 
 | 
 | ||||||
| Kernels will typically contain a native Python module with precompiled | Kernels will typically contain a native Python module with precompiled | ||||||
| compute kernels and bindings. This module must fulfill the requirements | compute kernels and bindings. This module must fulfill the following | ||||||
| outlined in this section. For all operating systems, a kernel must not | requirements: | ||||||
| have dynamic library dependencies outside: |  | ||||||
| 
 |  | ||||||
| - Torch; |  | ||||||
| - CUDA/ROCm libraries installed as dependencies of Torch. |  | ||||||
| 
 |  | ||||||
| ### Linux |  | ||||||
| 
 | 
 | ||||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||||
|   for compatibility with Python 3.9 and later. |   for compatibility with Python 3.9 and later. | ||||||
| @ -56,18 +60,12 @@ have dynamic library dependencies outside: | |||||||
|   - CXXABI 1.3.11 |   - CXXABI 1.3.11 | ||||||
|   - GCC 7.0.0 |   - GCC 7.0.0 | ||||||
| 
 | 
 | ||||||
| These requirement can be checked with the ABI checker (see below). |   These requirement can be checked with the ABI checker (see below). | ||||||
| 
 | 
 | ||||||
| ### macOS | - No dynamic library dependencies outside: | ||||||
| 
 | 
 | ||||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) |   - Torch; | ||||||
|   for compatibility with Python 3.9 and later. |   - CUDA/ROCm libraries installed as dependencies of Torch. | ||||||
| - macOS deployment target 15.0. |  | ||||||
| - Metal 3.0 (`-std=metal3.0`). |  | ||||||
| 
 |  | ||||||
| The ABI3 requirement can be checked with the ABI checker (see below). |  | ||||||
| 
 |  | ||||||
| ### ABI checker |  | ||||||
| 
 | 
 | ||||||
| The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with | ||||||
| [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | [`kernel-abi-check`](https://crates.io/crates/kernel-abi-check): | ||||||
| @ -121,12 +119,9 @@ requirements: | |||||||
| - The `forward` method has a signature that is compatible with the | - The `forward` method has a signature that is compatible with the | ||||||
|   `forward` method that it is extending. |   `forward` method that it is extending. | ||||||
| 
 | 
 | ||||||
| There are two exceptions to the _no class variables rule_: | The only exception to the _no class variables rule_ is addition of a | ||||||
| 
 | `has_backward` class variable. This variable is used to indicate whether | ||||||
| 1. The `has_backward` variable can be used to indicate whether the layer has | the layer has a backward pass implemented (`True` when absent). | ||||||
|    a backward pass implemented (`True` when absent). |  | ||||||
| 2. The `can_torch_compile` variable can be used to indicate whether the layer |  | ||||||
|    supports `torch.compile` (`False` when absent). |  | ||||||
| 
 | 
 | ||||||
| This is an example of a pure layer: | This is an example of a pure layer: | ||||||
| 
 | 
 | ||||||
							
								
								
									
										79
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								docs/source/layers.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,79 @@ | |||||||
|  | # Layers | ||||||
|  |  | ||||||
|  | A kernel can provide layers in addition to kernel functions. A layer from | ||||||
|  | the Hub can replace the `forward` method of an existing layer for a certain | ||||||
|  | device type. This makes it possible to provide more performant kernels for | ||||||
|  | existing layers. | ||||||
|  |  | ||||||
|  | See [Kernel requirements](kernel-requirements.md) for more information the | ||||||
|  | requirements of Hub layers. | ||||||
|  |  | ||||||
|  | ## Making a layer extensible with kernels from the hub | ||||||
|  |  | ||||||
|  | ### Using a decorator | ||||||
|  |  | ||||||
|  | A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||||
|  | decorator. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | @use_kernel_forward_from_hub("SiluAndMul") | ||||||
|  | class SiluAndMul(nn.Module): | ||||||
|  |     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         d = input.shape[-1] // 2 | ||||||
|  |         return F.silu(input[..., :d]) * input[..., d:] | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The decorator changes the layer, so that other implementations of the `forward` | ||||||
|  | method can be registered using the name `SiluAndMul`. | ||||||
|  |  | ||||||
|  | ### External layers | ||||||
|  |  | ||||||
|  | An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||||
|  | decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function. | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from somelibrary import SiluAndMul | ||||||
|  |  | ||||||
|  | replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The `register_kernel_mapping` call maps the name `SiluAndMul` to actual | ||||||
|  | hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer) | ||||||
|  | section for more information. | ||||||
|  |  | ||||||
|  | **Warning:** we strongly recommend using layers with a decorator, since | ||||||
|  | it signifies that the maintainer intends to keep the `forward` signature | ||||||
|  | compatible with layers from the hub. | ||||||
|  |  | ||||||
|  | ## Registering a hub kernel for a layer | ||||||
|  |  | ||||||
|  | Once a layer is made extensible, users can register hub kernels for it | ||||||
|  | by name using the `register_kernel_mapping` function. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             revision="layers", | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This will register the kernel mapping in the current context, which is | ||||||
|  | normally global. It is recommended to scope the mapping to where it is | ||||||
|  | used with the `use_kernel_mapping` context manager: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | with use_kernel_mapping(kernel_layer_mapping): | ||||||
|  |     # Use the layer for which the mapping is applied. | ||||||
|  |     ... | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This ensures that the mapping is not active anymore outside the | ||||||
|  | `with`-scope. | ||||||
							
								
								
									
										55
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										55
									
								
								flake.lock
									
									
									
										generated
									
									
									
								
							| @ -51,38 +51,18 @@ | |||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
|     }, |     }, | ||||||
|     "hf-nix": { |  | ||||||
|       "inputs": { |  | ||||||
|         "flake-compat": "flake-compat", |  | ||||||
|         "flake-utils": "flake-utils_2", |  | ||||||
|         "nixpkgs": "nixpkgs" |  | ||||||
|       }, |  | ||||||
|       "locked": { |  | ||||||
|         "lastModified": 1750775451, |  | ||||||
|         "narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=", |  | ||||||
|         "owner": "huggingface", |  | ||||||
|         "repo": "hf-nix", |  | ||||||
|         "rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7", |  | ||||||
|         "type": "github" |  | ||||||
|       }, |  | ||||||
|       "original": { |  | ||||||
|         "owner": "huggingface", |  | ||||||
|         "repo": "hf-nix", |  | ||||||
|         "type": "github" |  | ||||||
|       } |  | ||||||
|     }, |  | ||||||
|     "nixpkgs": { |     "nixpkgs": { | ||||||
|       "locked": { |       "locked": { | ||||||
|         "lastModified": 1747820358, |         "lastModified": 1737453259, | ||||||
|         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", |         "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", | ||||||
|         "owner": "danieldk", |         "owner": "danieldk", | ||||||
|         "repo": "nixpkgs", |         "repo": "nixpkgs", | ||||||
|         "rev": "d3c1681180717528068082103bf323147de6ab0b", |         "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       }, |       }, | ||||||
|       "original": { |       "original": { | ||||||
|         "owner": "danieldk", |         "owner": "danieldk", | ||||||
|         "ref": "cudatoolkit-12.9-kernel-builder", |         "ref": "outlines-v0.1.4-tgi", | ||||||
|         "repo": "nixpkgs", |         "repo": "nixpkgs", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
| @ -90,11 +70,11 @@ | |||||||
|     "root": { |     "root": { | ||||||
|       "inputs": { |       "inputs": { | ||||||
|         "flake-utils": "flake-utils", |         "flake-utils": "flake-utils", | ||||||
|         "hf-nix": "hf-nix", |  | ||||||
|         "nixpkgs": [ |         "nixpkgs": [ | ||||||
|           "hf-nix", |           "tgi-nix", | ||||||
|           "nixpkgs" |           "nixpkgs" | ||||||
|         ] |         ], | ||||||
|  |         "tgi-nix": "tgi-nix" | ||||||
|       } |       } | ||||||
|     }, |     }, | ||||||
|     "systems": { |     "systems": { | ||||||
| @ -126,6 +106,27 @@ | |||||||
|         "repo": "default", |         "repo": "default", | ||||||
|         "type": "github" |         "type": "github" | ||||||
|       } |       } | ||||||
|  |     }, | ||||||
|  |     "tgi-nix": { | ||||||
|  |       "inputs": { | ||||||
|  |         "flake-compat": "flake-compat", | ||||||
|  |         "flake-utils": "flake-utils_2", | ||||||
|  |         "nixpkgs": "nixpkgs" | ||||||
|  |       }, | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1741617161, | ||||||
|  |         "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=", | ||||||
|  |         "owner": "huggingface", | ||||||
|  |         "repo": "text-generation-inference-nix", | ||||||
|  |         "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "huggingface", | ||||||
|  |         "ref": "kernels-0.2.0", | ||||||
|  |         "repo": "text-generation-inference-nix", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|     } |     } | ||||||
|   }, |   }, | ||||||
|   "root": "root", |   "root": "root", | ||||||
|  | |||||||
							
								
								
									
										15
									
								
								flake.nix
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								flake.nix
									
									
									
									
									
								
							| @ -1,7 +1,7 @@ | |||||||
| { | { | ||||||
|   inputs = { |   inputs = { | ||||||
|     hf-nix.url = "github:huggingface/hf-nix"; |     tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0"; | ||||||
|     nixpkgs.follows = "hf-nix/nixpkgs"; |     nixpkgs.follows = "tgi-nix/nixpkgs"; | ||||||
|     flake-utils.url = "github:numtide/flake-utils"; |     flake-utils.url = "github:numtide/flake-utils"; | ||||||
|   }; |   }; | ||||||
|   outputs = |   outputs = | ||||||
| @ -9,21 +9,21 @@ | |||||||
|       self, |       self, | ||||||
|       nixpkgs, |       nixpkgs, | ||||||
|       flake-utils, |       flake-utils, | ||||||
|       hf-nix, |       tgi-nix, | ||||||
|     }: |     }: | ||||||
|     flake-utils.lib.eachDefaultSystem ( |     flake-utils.lib.eachDefaultSystem ( | ||||||
|       system: |       system: | ||||||
|       let |       let | ||||||
|         pkgs = import nixpkgs { |         pkgs = import nixpkgs { | ||||||
|           inherit system; |           inherit system; | ||||||
|           config = hf-nix.lib.config system; |           inherit (tgi-nix.lib) config; | ||||||
|           overlays = [ |           overlays = [ | ||||||
|             hf-nix.overlays.default |             tgi-nix.overlays.default | ||||||
|           ]; |           ]; | ||||||
|         }; |         }; | ||||||
|       in |       in | ||||||
|       { |       { | ||||||
|         formatter = pkgs.nixfmt-tree; |         formatter = pkgs.nixfmt-rfc-style; | ||||||
|         devShells = with pkgs; rec { |         devShells = with pkgs; rec { | ||||||
|           default = mkShell { |           default = mkShell { | ||||||
|             buildInputs = |             buildInputs = | ||||||
| @ -34,13 +34,10 @@ | |||||||
|                 ruff |                 ruff | ||||||
|               ] |               ] | ||||||
|               ++ (with python3.pkgs; [ |               ++ (with python3.pkgs; [ | ||||||
|                 docutils |  | ||||||
|                 huggingface-hub |                 huggingface-hub | ||||||
|                 pytest |                 pytest | ||||||
|                 pytest-benchmark |                 pytest-benchmark | ||||||
|                 pyyaml |  | ||||||
|                 torch |                 torch | ||||||
|                 types-pyyaml |  | ||||||
|                 venvShellHook |                 venvShellHook | ||||||
|               ]); |               ]); | ||||||
|  |  | ||||||
|  | |||||||
| @ -1,6 +1,6 @@ | |||||||
| [project] | [project] | ||||||
| name = "kernels" | name = "kernels" | ||||||
| version = "0.7.0" | version = "0.4.4" | ||||||
| description = "Download compute kernels" | description = "Download compute kernels" | ||||||
| authors = [ | authors = [ | ||||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, |   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||||
| @ -14,7 +14,6 @@ requires-python = ">= 3.9" | |||||||
| dependencies = [ | dependencies = [ | ||||||
|   "huggingface_hub>=0.26.0,<1.0", |   "huggingface_hub>=0.26.0,<1.0", | ||||||
|   "packaging>=20.0", |   "packaging>=20.0", | ||||||
|   "pyyaml>=6", |  | ||||||
|   "tomli>=2.0; python_version<'3.11'", |   "tomli>=2.0; python_version<'3.11'", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| @ -24,16 +23,17 @@ build-backend = "setuptools.build_meta" | |||||||
|  |  | ||||||
| [dependency-groups] | [dependency-groups] | ||||||
| dev = [ | dev = [ | ||||||
|   "mypy >= 1.15.0", |   "mypy == 1.14.1", | ||||||
|   "pytest >=8", |   "pytest >=8", | ||||||
|   # Whatever version is compatible with pytest. |   # Whatever version is compatible with pytest. | ||||||
|   "pytest-benchmark", |   "pytest-benchmark", | ||||||
|   "torch >=2.5", |   "torch >=2.5", | ||||||
|   "types-pyyaml" |  | ||||||
| ] | ] | ||||||
|  |  | ||||||
| [project.optional-dependencies] | [project.optional-dependencies] | ||||||
| torch = ["torch"] | docs = [ | ||||||
|  |   "hf-doc-builder", | ||||||
|  | ] | ||||||
|  |  | ||||||
| [project.scripts] | [project.scripts] | ||||||
| kernels = "kernels.cli:main" | kernels = "kernels.cli:main" | ||||||
| @ -41,7 +41,6 @@ kernels = "kernels.cli:main" | |||||||
| [project.entry-points."egg_info.writers"] | [project.entry-points."egg_info.writers"] | ||||||
| "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||||
|  |  | ||||||
|  |  | ||||||
| [tool.ruff] | [tool.ruff] | ||||||
| exclude = [ | exclude = [ | ||||||
|   ".eggs", |   ".eggs", | ||||||
|  | |||||||
| @ -1,4 +0,0 @@ | |||||||
| [pytest] |  | ||||||
| markers = |  | ||||||
|     darwin_only: marks tests that should only run on macOS |  | ||||||
|     linux_only: marks tests that should only run on Linux |  | ||||||
| @ -1,8 +1,8 @@ | |||||||
|  | import importlib.metadata | ||||||
|  |  | ||||||
| from kernels.layer import ( | from kernels.layer import ( | ||||||
|     Device, |     Device, | ||||||
|     LayerRepository, |     LayerRepository, | ||||||
|     Mode, |  | ||||||
|     kernelize, |  | ||||||
|     register_kernel_mapping, |     register_kernel_mapping, | ||||||
|     replace_kernel_forward_from_hub, |     replace_kernel_forward_from_hub, | ||||||
|     use_kernel_forward_from_hub, |     use_kernel_forward_from_hub, | ||||||
| @ -10,7 +10,6 @@ from kernels.layer import ( | |||||||
| ) | ) | ||||||
| from kernels.utils import ( | from kernels.utils import ( | ||||||
|     get_kernel, |     get_kernel, | ||||||
|     get_local_kernel, |  | ||||||
|     get_locked_kernel, |     get_locked_kernel, | ||||||
|     has_kernel, |     has_kernel, | ||||||
|     install_kernel, |     install_kernel, | ||||||
| @ -18,18 +17,17 @@ from kernels.utils import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "Device", |  | ||||||
|     "LayerRepository", |  | ||||||
|     "Mode", |  | ||||||
|     "get_kernel", |     "get_kernel", | ||||||
|     "get_local_kernel", |  | ||||||
|     "get_locked_kernel", |     "get_locked_kernel", | ||||||
|     "has_kernel", |     "has_kernel", | ||||||
|     "install_kernel", |  | ||||||
|     "kernelize", |  | ||||||
|     "load_kernel", |     "load_kernel", | ||||||
|     "register_kernel_mapping", |     "install_kernel", | ||||||
|     "replace_kernel_forward_from_hub", |  | ||||||
|     "use_kernel_forward_from_hub", |     "use_kernel_forward_from_hub", | ||||||
|     "use_kernel_mapping", |     "use_kernel_mapping", | ||||||
|  |     "register_kernel_mapping", | ||||||
|  |     "replace_kernel_forward_from_hub", | ||||||
|  |     "LayerRepository", | ||||||
|  |     "Device", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | __version__ = importlib.metadata.version("kernels") | ||||||
|  | |||||||
| @ -1,751 +0,0 @@ | |||||||
| # coding=utf-8 |  | ||||||
| # Copyright 2021 The HuggingFace Team. All rights reserved. |  | ||||||
| # |  | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); |  | ||||||
| # you may not use this file except in compliance with the License. |  | ||||||
| # You may obtain a copy of the License at |  | ||||||
| # |  | ||||||
| #     http://www.apache.org/licenses/LICENSE-2.0 |  | ||||||
| # |  | ||||||
| # Unless required by applicable law or agreed to in writing, software |  | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, |  | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |  | ||||||
| # See the License for the specific language governing permissions and |  | ||||||
| # limitations under the License. |  | ||||||
|  |  | ||||||
| # Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py |  | ||||||
|  |  | ||||||
| import re |  | ||||||
|  |  | ||||||
| # Re pattern to catch things inside ` ` in :obj:`thing`. |  | ||||||
| _re_obj = re.compile(r":obj:`([^`]+)`") |  | ||||||
| # Re pattern to catch things inside ` ` in :math:`thing`. |  | ||||||
| _re_math = re.compile(r":math:`([^`]+)`") |  | ||||||
| # Re pattern to catch things between single backquotes. |  | ||||||
| _re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)") |  | ||||||
| # Re pattern to catch things between double backquotes. |  | ||||||
| _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)") |  | ||||||
| # Re pattern to catch things inside ` ` in :func/class/meth:`thing`. |  | ||||||
| _re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_rst_formatting(text): |  | ||||||
|     """ |  | ||||||
|     Convert rst syntax for formatting to markdown in a given text. |  | ||||||
|     """ |  | ||||||
|     # Remove :class:, :func: and :meth: markers. To code-links and put double backquotes |  | ||||||
|     # (to not be caught by the italic conversion). |  | ||||||
|     text = _re_func_class.sub(r"[``\1``]", text) |  | ||||||
|     # Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes |  | ||||||
|     # (to not be caught by the italic conversion). |  | ||||||
|     text = _re_obj.sub(r"``\1``", text) |  | ||||||
|     # Remove :math: markers. |  | ||||||
|     text = _re_math.sub(r"\\\\(\1\\\\)", text) |  | ||||||
|     # Convert content in single backquotes to italic. |  | ||||||
|     text = _re_single_backquotes.sub(r"\1*\2*\3", text) |  | ||||||
|     # Convert content in double backquotes to single backquotes. |  | ||||||
|     text = _re_double_backquotes.sub(r"\1`\2`\3", text) |  | ||||||
|     # Remove remaining :: |  | ||||||
|     text = re.sub(r"::\n", "", text) |  | ||||||
|  |  | ||||||
|     # Remove new lines inside blocks in backsticks as they will be kept. |  | ||||||
|     lines = text.split("\n") |  | ||||||
|     in_code = False |  | ||||||
|     text = None |  | ||||||
|     for line in lines: |  | ||||||
|         if in_code: |  | ||||||
|             splits = line.split("`") |  | ||||||
|             in_code = len(splits) > 1 and len(splits) % 2 == 1 |  | ||||||
|             if len(splits) == 1: |  | ||||||
|                 # Some forgotten lone backstick |  | ||||||
|                 text += "\n" + line |  | ||||||
|             else: |  | ||||||
|                 text += " " + line.lstrip() |  | ||||||
|         else: |  | ||||||
|             if text is not None: |  | ||||||
|                 text += "\n" + line |  | ||||||
|             else: |  | ||||||
|                 text = line |  | ||||||
|             splits = line.split("`") |  | ||||||
|             in_code = len(splits) % 2 == 0 |  | ||||||
|     return text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Re pattern to catch description and url in links of the form `description <url>`_. |  | ||||||
| _re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+") |  | ||||||
| # Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_. |  | ||||||
| _re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`") |  | ||||||
| # Re pattern to catch reference in links of the form :doc:`reference`. |  | ||||||
| _re_simple_doc = re.compile(r":doc:`([^`<]*)`") |  | ||||||
| # Re pattern to catch description and reference in links of the form :doc:`description <reference>`. |  | ||||||
| _re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`") |  | ||||||
| # Re pattern to catch reference in links of the form :ref:`reference`. |  | ||||||
| _re_simple_ref = re.compile(r":ref:`([^`<]*)`") |  | ||||||
| # Re pattern to catch description and reference in links of the form :ref:`description <reference>`. |  | ||||||
| _re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_rst_links(text, page_info): |  | ||||||
|     """ |  | ||||||
|     Convert the rst links in text to markdown. |  | ||||||
|     """ |  | ||||||
|     if "package_name" not in page_info: |  | ||||||
|         raise ValueError("`page_info` must contain at least the package_name.") |  | ||||||
|     package_name = page_info["package_name"] |  | ||||||
|     version = page_info.get("version", "main") |  | ||||||
|     language = page_info.get("language", "en") |  | ||||||
|     no_prefix = page_info.get("no_prefix", False) |  | ||||||
|  |  | ||||||
|     prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/" |  | ||||||
|     # Links of the form :doc:`page` |  | ||||||
|     text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text) |  | ||||||
|     # Links of the form :doc:`text <page>` |  | ||||||
|     text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text) |  | ||||||
|  |  | ||||||
|     if "page" in page_info and not no_prefix: |  | ||||||
|         page = str(page_info["page"]) |  | ||||||
|         if page.endswith(".html"): |  | ||||||
|             page = page[:-5] |  | ||||||
|         prefix = f"{prefix}{page}" |  | ||||||
|     else: |  | ||||||
|         prefix = "" |  | ||||||
|     # Refs of the form :ref:`page` |  | ||||||
|     text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text) |  | ||||||
|     # Refs of the form :ref:`text <page>` |  | ||||||
|     text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text) |  | ||||||
|  |  | ||||||
|     # Links with a prefix |  | ||||||
|     # TODO: when it exists, use the API to deal with prefix links properly. |  | ||||||
|     prefix = f"https://github.com/huggingface/{package_name}/tree/main/" |  | ||||||
|     text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text) |  | ||||||
|     # Other links |  | ||||||
|     text = _re_links.sub(r"[\1](\2)", text) |  | ||||||
|     # Relative links or Transformers links need to remove the .html |  | ||||||
|     if ( |  | ||||||
|         "(https://https://huggingface.co/" in text |  | ||||||
|         or re.search(r"\(\.+/", text) is not None |  | ||||||
|     ): |  | ||||||
|         text = text.replace(".html", "") |  | ||||||
|     return text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Re pattern that catches examples blocks of the form `Example::`. |  | ||||||
| _re_example = re.compile(r"^\s*(\S.*)::\s*$") |  | ||||||
| # Re pattern that catches rst blocks of the form `.. block_name::`. |  | ||||||
| _re_block = re.compile(r"^\s*\.\.\s+(\S+)::") |  | ||||||
| # Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`. |  | ||||||
| _re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_empty_line(line): |  | ||||||
|     return len(line) == 0 or line.isspace() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def find_indent(line): |  | ||||||
|     """ |  | ||||||
|     Returns the number of spaces that start a line indent. |  | ||||||
|     """ |  | ||||||
|     search = re.search(r"^(\s*)(?:\S|$)", line) |  | ||||||
|     if search is None: |  | ||||||
|         return 0 |  | ||||||
|     return len(search.groups()[0]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _re_rst_option = re.compile(r"^\s*:(\S+):(.*)$") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_special_chars(text): |  | ||||||
|     """ |  | ||||||
|     Converts { and < that have special meanings in MDX. |  | ||||||
|     """ |  | ||||||
|     text = text.replace("{", "&lcub;") |  | ||||||
|     # We don't want to replace those by the HTML code, so we temporarily set them at LTHTML |  | ||||||
|     text = re.sub( |  | ||||||
|         r"<(img|br|hr|Youtube)", r"LTHTML\1", text |  | ||||||
|     )  # html void elements with no closing counterpart |  | ||||||
|     _re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL) |  | ||||||
|     while _re_lt_html.search(text): |  | ||||||
|         text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text) |  | ||||||
|     text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&lt;\2", text) |  | ||||||
|     text = text.replace("LTHTML", "<") |  | ||||||
|     return text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_options(block_content): |  | ||||||
|     """ |  | ||||||
|     Parses the option in some rst block content. |  | ||||||
|     """ |  | ||||||
|     block_lines = block_content.split("\n") |  | ||||||
|     block_indent = find_indent(block_lines[0]) |  | ||||||
|     current_option = None |  | ||||||
|     result = {} |  | ||||||
|     for line in block_lines: |  | ||||||
|         if _re_rst_option.search(line) is not None: |  | ||||||
|             current_option, value = _re_rst_option.search(line).groups() |  | ||||||
|             result[current_option] = value.lstrip() |  | ||||||
|         elif find_indent(line) > block_indent: |  | ||||||
|             result[current_option] += " " + line.lstrip() |  | ||||||
|  |  | ||||||
|     return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def apply_min_indent(text, min_indent): |  | ||||||
|     """ |  | ||||||
|     Make sure all lines in a text are have a minimum indentation. |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         text (`str`): The text to treat. |  | ||||||
|         min_indent (`int`): The minimal indentation. |  | ||||||
|  |  | ||||||
|     Returns: |  | ||||||
|         `str`: The processed text. |  | ||||||
|     """ |  | ||||||
|     lines = text.split("\n") |  | ||||||
|     idx = 0 |  | ||||||
|     while idx < len(lines): |  | ||||||
|         if is_empty_line(lines[idx]): |  | ||||||
|             idx += 1 |  | ||||||
|             continue |  | ||||||
|         indent = find_indent(lines[idx]) |  | ||||||
|         if indent < min_indent: |  | ||||||
|             while idx < len(lines) and ( |  | ||||||
|                 find_indent(lines[idx]) >= indent or is_empty_line(lines[idx]) |  | ||||||
|             ): |  | ||||||
|                 if not is_empty_line(lines[idx]): |  | ||||||
|                     lines[idx] = " " * (min_indent - indent) + lines[idx] |  | ||||||
|                 idx += 1 |  | ||||||
|         else: |  | ||||||
|             idx += 1 |  | ||||||
|  |  | ||||||
|     return "\n".join(lines) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_rst_blocks(text, page_info): |  | ||||||
|     """ |  | ||||||
|     Converts rst special blocks (examples, notes) into MDX. |  | ||||||
|     """ |  | ||||||
|     if "package_name" not in page_info: |  | ||||||
|         raise ValueError("`page_info` must contain at least the package_name.") |  | ||||||
|     package_name = page_info["package_name"] |  | ||||||
|     version = page_info.get("version", "main") |  | ||||||
|     language = page_info.get("language", "en") |  | ||||||
|  |  | ||||||
|     lines = text.split("\n") |  | ||||||
|     idx = 0 |  | ||||||
|     new_lines = [] |  | ||||||
|     while idx < len(lines): |  | ||||||
|         block_type = None |  | ||||||
|         block_info = None |  | ||||||
|         if _re_block.search(lines[idx]) is not None: |  | ||||||
|             block_type = _re_block.search(lines[idx]).groups()[0] |  | ||||||
|             if _re_block_info.search(lines[idx]) is not None: |  | ||||||
|                 block_info = _re_block_info.search(lines[idx]).groups()[0] |  | ||||||
|         elif _re_example.search(lines[idx]) is not None: |  | ||||||
|             block_type = "code-block-example" |  | ||||||
|             block_info = "python" |  | ||||||
|             example_name = _re_example.search(lines[idx]).groups()[0] |  | ||||||
|             new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n") |  | ||||||
|         elif lines[idx].strip() == "..": |  | ||||||
|             block_type = "comment" |  | ||||||
|         elif lines[idx].strip() == "::": |  | ||||||
|             block_type = "code-block" |  | ||||||
|  |  | ||||||
|         if block_type is not None: |  | ||||||
|             block_indent = find_indent(lines[idx]) |  | ||||||
|             # Find the next nonempty line |  | ||||||
|             idx += 1 |  | ||||||
|             while idx < len(lines) and is_empty_line(lines[idx]): |  | ||||||
|                 idx += 1 |  | ||||||
|             # Grab the indent of the return line, this block will stop when we unindent under it (or has already) |  | ||||||
|             example_indent = ( |  | ||||||
|                 find_indent(lines[idx]) if idx < len(lines) else block_indent |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             if example_indent == block_indent: |  | ||||||
|                 block_content = "" |  | ||||||
|             else: |  | ||||||
|                 block_lines = [] |  | ||||||
|                 while idx < len(lines) and ( |  | ||||||
|                     is_empty_line(lines[idx]) |  | ||||||
|                     or find_indent(lines[idx]) >= example_indent |  | ||||||
|                 ): |  | ||||||
|                     block_lines.append(lines[idx][example_indent:]) |  | ||||||
|                     idx += 1 |  | ||||||
|                 block_content = "\n".join(block_lines) |  | ||||||
|  |  | ||||||
|             if block_type in ["code", "code-block"]: |  | ||||||
|                 prefix = "```" if block_info is None else f"```{block_info}" |  | ||||||
|                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n") |  | ||||||
|             elif block_type == "code-block-example": |  | ||||||
|                 prefix = f"<example>```{block_info}" |  | ||||||
|                 new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>") |  | ||||||
|             elif block_type == "note": |  | ||||||
|                 new_lines.append( |  | ||||||
|                     apply_min_indent( |  | ||||||
|                         f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             elif block_type == "warning": |  | ||||||
|                 new_lines.append( |  | ||||||
|                     apply_min_indent( |  | ||||||
|                         "<Tip warning={true}>\n\n" |  | ||||||
|                         + f"{block_content.strip()}\n\n</Tip>\n", |  | ||||||
|                         block_indent, |  | ||||||
|                     ) |  | ||||||
|                 ) |  | ||||||
|             elif block_type == "raw": |  | ||||||
|                 new_lines.append(block_content.strip() + "\n") |  | ||||||
|             elif block_type == "math": |  | ||||||
|                 new_lines.append(f"$${block_content.strip()}$$\n") |  | ||||||
|             elif block_type == "comment": |  | ||||||
|                 new_lines.append(f"<!--{block_content.strip()}\n-->\n") |  | ||||||
|             elif block_type == "autofunction": |  | ||||||
|                 if block_info is not None: |  | ||||||
|                     new_lines.append(f"[[autodoc]] {block_info}\n") |  | ||||||
|             elif block_type == "autoclass": |  | ||||||
|                 if block_info is not None: |  | ||||||
|                     block = f"[[autodoc]] {block_info}\n" |  | ||||||
|                     options = parse_options(block_content) |  | ||||||
|                     if "special-members" in options: |  | ||||||
|                         special_members = options["special-members"].split(", ") |  | ||||||
|                         for special_member in special_members: |  | ||||||
|                             block += f"    - {special_member}\n" |  | ||||||
|                     if "members" in options: |  | ||||||
|                         members = options["members"] |  | ||||||
|                         if len(members) == 0: |  | ||||||
|                             block += "    - all\n" |  | ||||||
|                         else: |  | ||||||
|                             for member in members.split(", "): |  | ||||||
|                                 block += f"    - {member}\n" |  | ||||||
|                     new_lines.append(block) |  | ||||||
|             elif block_type == "image": |  | ||||||
|                 options = parse_options(block_content) |  | ||||||
|                 target = options.pop("target", None) |  | ||||||
|                 if block_info is not None: |  | ||||||
|                     options["src"] = block_info |  | ||||||
|                 else: |  | ||||||
|                     if target is None: |  | ||||||
|                         raise ValueError("Image source not defined.") |  | ||||||
|                     options["src"] = target |  | ||||||
|                 # Adapt path |  | ||||||
|                 options["src"] = options["src"].replace( |  | ||||||
|                     "/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/" |  | ||||||
|                 ) |  | ||||||
|                 html_code = " ".join( |  | ||||||
|                     [f'{key}="{value}"' for key, value in options.items()] |  | ||||||
|                 ) |  | ||||||
|                 new_lines.append(f"<img {html_code}/>\n") |  | ||||||
|  |  | ||||||
|             else: |  | ||||||
|                 new_lines.append( |  | ||||||
|                     f"{block_type},{block_info}\n{block_content.rstrip()}\n" |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|         else: |  | ||||||
|             new_lines.append(lines[idx]) |  | ||||||
|             idx += 1 |  | ||||||
|  |  | ||||||
|     return "\n".join(new_lines) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Re pattern that catches rst args blocks of the form `Parameters:`. |  | ||||||
| _re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$") |  | ||||||
| # Re pattern that catches return blocks of the form `Return:`. |  | ||||||
| _re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_return_line(line): |  | ||||||
|     """ |  | ||||||
|     Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:. |  | ||||||
|     """ |  | ||||||
|     splits_on_colon = line.split(":") |  | ||||||
|     idx = 1 |  | ||||||
|     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: |  | ||||||
|         idx += 2 |  | ||||||
|     if idx >= len(splits_on_colon): |  | ||||||
|         if len(splits_on_colon) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()): |  | ||||||
|             return line, "" |  | ||||||
|         return None, line |  | ||||||
|     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_raise_line(line): |  | ||||||
|     """ |  | ||||||
|     Split the raise line with format `SomeError some doc`. |  | ||||||
|     """ |  | ||||||
|     splits_on_colon = line.strip().split(" ") |  | ||||||
|     error_type, doc = splits_on_colon[0], " ".join(splits_on_colon[1:]) |  | ||||||
|     if error_type and error_type[-1] == ":": |  | ||||||
|         error_type = error_type[:-1] |  | ||||||
|     return error_type, doc |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_arg_line(line): |  | ||||||
|     """ |  | ||||||
|     Split the return line with format `type: some doc`. Type may contain colons in the form of :obj: or :class:. |  | ||||||
|     """ |  | ||||||
|     splits_on_colon = line.split(":") |  | ||||||
|     idx = 1 |  | ||||||
|     while idx < len(splits_on_colon) and splits_on_colon[idx] in ["obj", "class"]: |  | ||||||
|         idx += 2 |  | ||||||
|     if idx >= len(splits_on_colon): |  | ||||||
|         return line, "" |  | ||||||
|     return ":".join(splits_on_colon[:idx]), ":".join(splits_on_colon[idx:]) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class InvalidRstDocstringError(ValueError): |  | ||||||
|     pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _re_parameters = re.compile( |  | ||||||
|     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL |  | ||||||
| ) |  | ||||||
| _re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_rst_docstring(docstring): |  | ||||||
|     """ |  | ||||||
|     Parses a docstring written in rst, in particular the list of arguments and the return type. |  | ||||||
|     """ |  | ||||||
|     lines = docstring.split("\n") |  | ||||||
|     idx = 0 |  | ||||||
|     while idx < len(lines): |  | ||||||
|         # Parameters section |  | ||||||
|         if _re_args.search(lines[idx]) is not None: |  | ||||||
|             # Title of the section. |  | ||||||
|             lines[idx] = "<parameters>\n" |  | ||||||
|             # Find the next nonempty line |  | ||||||
|             idx += 1 |  | ||||||
|             while is_empty_line(lines[idx]): |  | ||||||
|                 idx += 1 |  | ||||||
|             # Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the |  | ||||||
|             # Returns or Raises block. |  | ||||||
|             param_indent = find_indent(lines[idx]) |  | ||||||
|             while ( |  | ||||||
|                 idx < len(lines) |  | ||||||
|                 and find_indent(lines[idx]) == param_indent |  | ||||||
|                 and _re_returns.search(lines[idx]) is None |  | ||||||
|             ): |  | ||||||
|                 intro, doc = split_arg_line(lines[idx]) |  | ||||||
|                 # Line starting with a > after indent indicate a "section title" in the parameters. |  | ||||||
|                 if intro.lstrip().startswith(">"): |  | ||||||
|                     lines[idx] = intro.lstrip() |  | ||||||
|                 else: |  | ||||||
|                     lines[idx] = ( |  | ||||||
|                         re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc |  | ||||||
|                     ) |  | ||||||
|                 idx += 1 |  | ||||||
|                 while idx < len(lines) and ( |  | ||||||
|                     is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent |  | ||||||
|                 ): |  | ||||||
|                     idx += 1 |  | ||||||
|             lines.insert(idx, "</parameters>\n") |  | ||||||
|             idx += 1 |  | ||||||
|  |  | ||||||
|         # Returns section |  | ||||||
|         elif _re_returns.search(lines[idx]) is not None: |  | ||||||
|             # tag is either `return` or `yield` |  | ||||||
|             tag = _re_returns.match(lines[idx]).group(1).lower() |  | ||||||
|             # Title of the section. |  | ||||||
|             lines[idx] = f"<{tag}s>\n" |  | ||||||
|             # Find the next nonempty line |  | ||||||
|             idx += 1 |  | ||||||
|             while is_empty_line(lines[idx]): |  | ||||||
|                 idx += 1 |  | ||||||
|  |  | ||||||
|             # Grab the indent of the return line, this block will stop when we unindent under it. |  | ||||||
|             return_indent = find_indent(lines[idx]) |  | ||||||
|             raised_errors = [] |  | ||||||
|             # The line may contain the return type. |  | ||||||
|             if tag in ["return", "yield"]: |  | ||||||
|                 return_type, return_description = split_return_line(lines[idx]) |  | ||||||
|                 lines[idx] = return_description |  | ||||||
|                 idx += 1 |  | ||||||
|                 while idx < len(lines) and ( |  | ||||||
|                     is_empty_line(lines[idx]) |  | ||||||
|                     or find_indent(lines[idx]) >= return_indent |  | ||||||
|                 ): |  | ||||||
|                     idx += 1 |  | ||||||
|             else: |  | ||||||
|                 while idx < len(lines) and find_indent(lines[idx]) == return_indent: |  | ||||||
|                     return_type, return_description = split_raise_line(lines[idx]) |  | ||||||
|                     raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type) |  | ||||||
|                     lines[idx] = "- " + raised_error + " -- " + return_description |  | ||||||
|                     md_link = _re_md_link.match(raised_error) |  | ||||||
|                     if md_link: |  | ||||||
|                         raised_error = md_link[1] |  | ||||||
|                         raised_error = re.sub( |  | ||||||
|                             r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error |  | ||||||
|                         ) |  | ||||||
|                     if raised_error not in raised_errors: |  | ||||||
|                         raised_errors.append(raised_error) |  | ||||||
|                     idx += 1 |  | ||||||
|                     while idx < len(lines) and ( |  | ||||||
|                         is_empty_line(lines[idx]) |  | ||||||
|                         or find_indent(lines[idx]) > return_indent |  | ||||||
|                     ): |  | ||||||
|                         idx += 1 |  | ||||||
|  |  | ||||||
|             lines.insert(idx, f"</{tag}s>\n") |  | ||||||
|             idx += 1 |  | ||||||
|  |  | ||||||
|             # Return block finished, we insert the return type if one was specified |  | ||||||
|             if tag in ["return", "yield"] and return_type is not None: |  | ||||||
|                 lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n" |  | ||||||
|             elif len(raised_errors) > 0: |  | ||||||
|                 # raised errors |  | ||||||
|                 lines[ |  | ||||||
|                     idx - 1 |  | ||||||
|                 ] += f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n" |  | ||||||
|  |  | ||||||
|         else: |  | ||||||
|             idx += 1 |  | ||||||
|  |  | ||||||
|     result = "\n".join(lines) |  | ||||||
|  |  | ||||||
|     # combine multiple <parameters> blocks into one block |  | ||||||
|     if result.count("<parameters>") > 1: |  | ||||||
|         parameters_blocks = _re_parameters.findall(result) |  | ||||||
|         parameters_blocks = [pb[0].strip() for pb in parameters_blocks] |  | ||||||
|         parameters_str = "\n".join(parameters_blocks) |  | ||||||
|         result = _re_parameters.sub("", result) |  | ||||||
|         result += f"\n<parameters>{parameters_str}</parameters>\n" |  | ||||||
|  |  | ||||||
|     return result |  | ||||||
|  |  | ||||||
|  |  | ||||||
| _re_list = re.compile(r"^\s*(-|\*|\d+\.)\s") |  | ||||||
| _re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def remove_indent(text): |  | ||||||
|     """ |  | ||||||
|     Remove indents in text, except the one linked to lists (or sublists). |  | ||||||
|     """ |  | ||||||
|     lines = text.split("\n") |  | ||||||
|     # List of indents to remember for nested lists |  | ||||||
|     current_indents = [] |  | ||||||
|     # List of new indents to remember for nested lists |  | ||||||
|     new_indents = [] |  | ||||||
|     is_inside_code = False |  | ||||||
|     code_indent = 0 |  | ||||||
|     for idx, line in enumerate(lines): |  | ||||||
|         # Line is an item in a list. |  | ||||||
|         if _re_list.search(line) is not None: |  | ||||||
|             indent = find_indent(line) |  | ||||||
|             # Is it a new list / new level of nestedness? |  | ||||||
|             if len(current_indents) == 0 or indent > current_indents[-1]: |  | ||||||
|                 current_indents.append(indent) |  | ||||||
|                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] |  | ||||||
|                 lines[idx] = " " * new_indent + line[indent:] |  | ||||||
|                 new_indent += len(_re_list.search(line).groups()[0]) + 1 |  | ||||||
|                 new_indents.append(new_indent) |  | ||||||
|             # Otherwise it's an existing level of list (current one, or previous one) |  | ||||||
|             else: |  | ||||||
|                 # Let's find the proper level of indentation |  | ||||||
|                 level = len(current_indents) - 1 |  | ||||||
|                 while level >= 0 and current_indents[level] != indent: |  | ||||||
|                     level -= 1 |  | ||||||
|                 current_indents = current_indents[: level + 1] |  | ||||||
|                 new_indents = new_indents[:level] |  | ||||||
|                 new_indent = 0 if len(new_indents) == 0 else new_indents[-1] |  | ||||||
|                 lines[idx] = " " * new_indent + line[indent:] |  | ||||||
|                 new_indent += len(_re_list.search(line).groups()[0]) + 1 |  | ||||||
|                 new_indents.append(new_indent) |  | ||||||
|  |  | ||||||
|         # Line is an autodoc, we keep the indent for the list just after if there is one. |  | ||||||
|         elif _re_autodoc.search(line) is not None: |  | ||||||
|             indent = find_indent(line) |  | ||||||
|             current_indents = [indent] |  | ||||||
|             new_indents = [4] |  | ||||||
|             lines[idx] = line.strip() |  | ||||||
|  |  | ||||||
|         # Deal with empty lines separately |  | ||||||
|         elif is_empty_line(line): |  | ||||||
|             lines[idx] = "" |  | ||||||
|  |  | ||||||
|         # Code blocks |  | ||||||
|         elif line.lstrip().startswith("```"): |  | ||||||
|             is_inside_code = not is_inside_code |  | ||||||
|             if is_inside_code: |  | ||||||
|                 code_indent = find_indent(line) |  | ||||||
|             lines[idx] = line[code_indent:] |  | ||||||
|         elif is_inside_code: |  | ||||||
|             lines[idx] = line[code_indent:] |  | ||||||
|  |  | ||||||
|         else: |  | ||||||
|             indent = find_indent(line) |  | ||||||
|             if len(current_indents) > 0 and indent > current_indents[-1]: |  | ||||||
|                 lines[idx] = " " * new_indents[-1] + line[indent:] |  | ||||||
|             elif len(current_indents) > 0: |  | ||||||
|                 # Let's find the proper level of indentation |  | ||||||
|                 level = len(current_indents) - 1 |  | ||||||
|                 while level >= 0 and current_indents[level] > indent: |  | ||||||
|                     level -= 1 |  | ||||||
|                 current_indents = current_indents[: level + 1] |  | ||||||
|                 if level >= 0: |  | ||||||
|                     if current_indents[level] < indent: |  | ||||||
|                         new_indents = new_indents[: level + 1] |  | ||||||
|                     else: |  | ||||||
|                         new_indents = new_indents[:level] |  | ||||||
|                     new_indent = 0 if len(new_indents) == 0 else new_indents[-1] |  | ||||||
|                     lines[idx] = " " * new_indent + line[indent:] |  | ||||||
|                     new_indents.append(new_indent) |  | ||||||
|                 else: |  | ||||||
|                     new_indents = [] |  | ||||||
|                     lines[idx] = line[indent:] |  | ||||||
|             else: |  | ||||||
|                 lines[idx] = line[indent:] |  | ||||||
|  |  | ||||||
|     return "\n".join(lines) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def base_rst_to_mdx(text, page_info, unindent=True): |  | ||||||
|     """ |  | ||||||
|     Convert a text from rst to mdx, with the base operations necessary for both docstrings and rst docs. |  | ||||||
|     """ |  | ||||||
|     text = convert_rst_links(text, page_info) |  | ||||||
|     text = convert_special_chars(text) |  | ||||||
|     text = convert_rst_blocks(text, page_info) |  | ||||||
|     # Convert * in lists to - to avoid the formatting conversion treat them as bold. |  | ||||||
|     text = re.sub(r"^(\s*)\*(\s)", r"\1-\2", text, flags=re.MULTILINE) |  | ||||||
|     text = convert_rst_formatting(text) |  | ||||||
|     return remove_indent(text) if unindent else text |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_rst_docstring_to_mdx(docstring, page_info): |  | ||||||
|     """ |  | ||||||
|     Convert a docstring written in rst to mdx. |  | ||||||
|     """ |  | ||||||
|     text = parse_rst_docstring(docstring) |  | ||||||
|     return base_rst_to_mdx(text, page_info) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def process_titles(lines): |  | ||||||
|     """Converts rst titles to markdown titles.""" |  | ||||||
|     title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ") |  | ||||||
|     title_levels = {} |  | ||||||
|     new_lines = [] |  | ||||||
|     for line in lines: |  | ||||||
|         if ( |  | ||||||
|             len(new_lines) > 0 |  | ||||||
|             and len(line) >= len(new_lines[-1]) |  | ||||||
|             and len(set(line)) == 1 |  | ||||||
|             and line[0] in title_chars |  | ||||||
|             and line != "::" |  | ||||||
|         ): |  | ||||||
|             char = line[0] |  | ||||||
|             level = title_levels.get(char, len(title_levels) + 1) |  | ||||||
|             if level not in title_levels: |  | ||||||
|                 title_levels[char] = level |  | ||||||
|             new_lines[-1] = f"{'#' * level} {new_lines[-1]}" |  | ||||||
|         else: |  | ||||||
|             new_lines.append(line) |  | ||||||
|     return new_lines |  | ||||||
|  |  | ||||||
|  |  | ||||||
| # Matches lines with a pattern of a table new line in rst. |  | ||||||
| _re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$") |  | ||||||
| # Matches lines with a pattern of a table new line in rst, with a first column empty. |  | ||||||
| _re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$") |  | ||||||
| # Matches lines with a pattern of a first table line in rst. |  | ||||||
| _re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$") |  | ||||||
| # Re pattern that catches anchors of the type .. reference: |  | ||||||
| _re_anchor_section = re.compile(r"^\.\.\s+_(\S+):") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_pt_tf_code_blocks(text): |  | ||||||
|     """ |  | ||||||
|     Split PyTorch and TensorFlow specific block codes. |  | ||||||
|     """ |  | ||||||
|     lines = text.split("\n") |  | ||||||
|     new_lines = [] |  | ||||||
|     idx = 0 |  | ||||||
|     while idx < len(lines): |  | ||||||
|         if lines[idx].startswith("```"): |  | ||||||
|             code_lines = {"common": [lines[idx]], "pytorch": [], "tensorflow": []} |  | ||||||
|             is_pytorch = False |  | ||||||
|             is_tensorflow = False |  | ||||||
|             idx += 1 |  | ||||||
|             while idx < len(lines) and lines[idx].strip() != "```": |  | ||||||
|                 if "## PYTORCH CODE" in lines[idx]: |  | ||||||
|                     is_pytorch = True |  | ||||||
|                     is_tensorflow = False |  | ||||||
|                 elif "## TENSORFLOW CODE" in lines[idx]: |  | ||||||
|                     is_tensorflow = True |  | ||||||
|                     is_pytorch = False |  | ||||||
|                 elif is_pytorch: |  | ||||||
|                     code_lines["pytorch"].append(lines[idx]) |  | ||||||
|                 elif is_tensorflow: |  | ||||||
|                     code_lines["tensorflow"].append(lines[idx]) |  | ||||||
|                 else: |  | ||||||
|                     code_lines["common"].append(lines[idx]) |  | ||||||
|                 idx += 1 |  | ||||||
|             if len(code_lines["pytorch"]) > 0 or len(code_lines["tensorflow"]) > 0: |  | ||||||
|                 block_lines = ["<frameworkcontent>", "<pt>"] |  | ||||||
|                 block_lines.extend(code_lines["common"].copy() + code_lines["pytorch"]) |  | ||||||
|                 block_lines.extend(["```", "</pt>", "<tf>"]) |  | ||||||
|                 block_lines.extend( |  | ||||||
|                     code_lines["common"].copy() + code_lines["tensorflow"] |  | ||||||
|                 ) |  | ||||||
|                 block_lines.extend(["```", "</tf>", "</frameworkcontent>"]) |  | ||||||
|                 new_lines.extend(block_lines) |  | ||||||
|             else: |  | ||||||
|                 block_lines = code_lines["common"] + ["```"] |  | ||||||
|                 new_lines.extend(block_lines) |  | ||||||
|             idx += 1 |  | ||||||
|         else: |  | ||||||
|             new_lines.append(lines[idx]) |  | ||||||
|             idx += 1 |  | ||||||
|     return "\n".join(new_lines) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_rst_to_mdx(rst_text, page_info, add_imports=True): |  | ||||||
|     """ |  | ||||||
|     Convert a document written in rst to mdx. |  | ||||||
|     """ |  | ||||||
|     lines = rst_text.split("\n") |  | ||||||
|     lines = process_titles(lines) |  | ||||||
|     if add_imports: |  | ||||||
|         new_lines = [ |  | ||||||
|             '<script lang="ts">', |  | ||||||
|             '	import Tip from "$lib/Tip.svelte";', |  | ||||||
|             '	import Youtube from "$lib/Youtube.svelte";', |  | ||||||
|             '	import Docstring from "$lib/Docstring.svelte";', |  | ||||||
|             '	import CodeBlock from "$lib/CodeBlock.svelte";', |  | ||||||
|             '	import CodeBlockFw from "$lib/CodeBlockFw.svelte";', |  | ||||||
|             '	import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";', |  | ||||||
|             '	import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";', |  | ||||||
|             '	import IconCopyLink from "$lib/IconCopyLink.svelte";', |  | ||||||
|             '	import FrameworkContent from "$lib/FrameworkContent.svelte";', |  | ||||||
|             '	import Markdown from "$lib/Markdown.svelte";', |  | ||||||
|             '	import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";', |  | ||||||
|             '	import Added from "$lib/Added.svelte";', |  | ||||||
|             '	import Changed from "$lib/Changed.svelte";', |  | ||||||
|             '	import Deprecated from "$lib/Deprecated.svelte";', |  | ||||||
|             '	import PipelineIcon from "$lib/PipelineIcon.svelte";', |  | ||||||
|             '	import PipelineTag from "$lib/PipelineTag.svelte";', |  | ||||||
|             "	", |  | ||||||
|             '	export let fw: "pt" | "tf"', |  | ||||||
|             "</script>", |  | ||||||
|             "<svelte:head>", |  | ||||||
|             '<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >', |  | ||||||
|             "</svelte:head>", |  | ||||||
|             "", |  | ||||||
|         ] |  | ||||||
|     else: |  | ||||||
|         new_lines = [] |  | ||||||
|     for line in lines: |  | ||||||
|         if _re_ignore_line_table.search(line) is not None: |  | ||||||
|             continue |  | ||||||
|         elif _re_ignore_line_table1.search(line) is not None: |  | ||||||
|             continue |  | ||||||
|         elif _re_sep_line_table.search(line) is not None: |  | ||||||
|             line = line.replace("=", "-").replace("+", "|") |  | ||||||
|         elif _re_anchor_section.search(line) is not None: |  | ||||||
|             anchor_name = _re_anchor_section.search(line).groups()[0] |  | ||||||
|             line = f"<a id='{anchor_name}'></a>" |  | ||||||
|         new_lines.append(line) |  | ||||||
|     text = "\n".join(new_lines) |  | ||||||
|  |  | ||||||
|     return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info)) |  | ||||||
| @ -8,9 +8,6 @@ from kernels.compat import tomllib | |||||||
| from kernels.lockfile import KernelLock, get_kernel_locks | from kernels.lockfile import KernelLock, get_kernel_locks | ||||||
| from kernels.utils import install_kernel, install_kernel_all_variants | from kernels.utils import install_kernel, install_kernel_all_variants | ||||||
|  |  | ||||||
| from .doc import generate_readme_for_kernel |  | ||||||
| from .wheel import build_variant_to_wheel |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     parser = argparse.ArgumentParser( |     parser = argparse.ArgumentParser( | ||||||
| @ -39,47 +36,6 @@ def main(): | |||||||
|     ) |     ) | ||||||
|     lock_parser.set_defaults(func=lock_kernels) |     lock_parser.set_defaults(func=lock_kernels) | ||||||
|  |  | ||||||
|     to_wheel_parser = subparsers.add_parser( |  | ||||||
|         "to-wheel", help="Convert a kernel to a wheel file" |  | ||||||
|     ) |  | ||||||
|     to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID") |  | ||||||
|     to_wheel_parser.add_argument("version", type=str, help="The kernel version") |  | ||||||
|     to_wheel_parser.add_argument( |  | ||||||
|         "--python-version", |  | ||||||
|         type=str, |  | ||||||
|         default="3.9", |  | ||||||
|         help="The minimum Python version. Must match the Python version that the kernel was compiled for.", |  | ||||||
|     ) |  | ||||||
|     to_wheel_parser.add_argument( |  | ||||||
|         "--manylinux-version", |  | ||||||
|         type=str, |  | ||||||
|         default="2.28", |  | ||||||
|         help="The manylinux version. Must match the manylinux version that the kernel was compiled for.", |  | ||||||
|     ) |  | ||||||
|     to_wheel_parser.set_defaults(func=kernels_to_wheel) |  | ||||||
|  |  | ||||||
|     # Add generate-readme subcommand parser |  | ||||||
|     generate_readme_parser = subparsers.add_parser( |  | ||||||
|         "generate-readme", |  | ||||||
|         help="Generate README snippets for a kernel's public functions", |  | ||||||
|     ) |  | ||||||
|     generate_readme_parser.add_argument( |  | ||||||
|         "repo_id", |  | ||||||
|         type=str, |  | ||||||
|         help="The kernel repo ID (e.g., kernels-community/activation)", |  | ||||||
|     ) |  | ||||||
|     generate_readme_parser.add_argument( |  | ||||||
|         "--revision", |  | ||||||
|         type=str, |  | ||||||
|         default="main", |  | ||||||
|         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", |  | ||||||
|     ) |  | ||||||
|     generate_readme_parser.set_defaults( |  | ||||||
|         func=lambda args: generate_readme_for_kernel( |  | ||||||
|             repo_id=args.repo_id, revision=args.revision |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     args.func(args) |     args.func(args) | ||||||
|  |  | ||||||
| @ -121,24 +77,6 @@ def download_kernels(args): | |||||||
|         sys.exit(1) |         sys.exit(1) | ||||||
|  |  | ||||||
|  |  | ||||||
| def kernels_to_wheel(args): |  | ||||||
|     variants_path = install_kernel_all_variants( |  | ||||||
|         repo_id=args.repo_id, revision=f"v{args.version}" |  | ||||||
|     ) |  | ||||||
|     for variant_path in variants_path.iterdir(): |  | ||||||
|         if not variant_path.is_dir(): |  | ||||||
|             continue |  | ||||||
|         wheel_path = build_variant_to_wheel( |  | ||||||
|             manylinux_version=args.manylinux_version, |  | ||||||
|             python_version=args.python_version, |  | ||||||
|             repo_id=args.repo_id, |  | ||||||
|             version=args.version, |  | ||||||
|             variant_path=variant_path, |  | ||||||
|             wheel_dir=Path("."), |  | ||||||
|         ) |  | ||||||
|         print(f"☸️ {wheel_path.name}", file=sys.stderr) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def lock_kernels(args): | def lock_kernels(args): | ||||||
|     with open(args.project_dir / "pyproject.toml", "rb") as f: |     with open(args.project_dir / "pyproject.toml", "rb") as f: | ||||||
|         data = tomllib.load(f) |         data = tomllib.load(f) | ||||||
|  | |||||||
| @ -1,242 +0,0 @@ | |||||||
| import inspect |  | ||||||
| import re |  | ||||||
| import sys |  | ||||||
| from types import ModuleType |  | ||||||
|  |  | ||||||
| import yaml |  | ||||||
|  |  | ||||||
| from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx |  | ||||||
| from .utils import get_kernel |  | ||||||
|  |  | ||||||
| _RE_PARAMETERS = re.compile( |  | ||||||
|     r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL |  | ||||||
| ) |  | ||||||
| _RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL) |  | ||||||
| _RE_RETURNTYPE = re.compile( |  | ||||||
|     r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _extract_description_before_tags(docstring_mdx: str) -> str: |  | ||||||
|     """Extract the description part of a docstring before any tags.""" |  | ||||||
|     params_pos = docstring_mdx.find("<parameters>") |  | ||||||
|     returns_pos = docstring_mdx.find("<returns>") |  | ||||||
|     returntype_pos = docstring_mdx.find("<returntype>") |  | ||||||
|     positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1] |  | ||||||
|  |  | ||||||
|     if positions: |  | ||||||
|         first_tag_pos = min(positions) |  | ||||||
|         return docstring_mdx[:first_tag_pos].strip() |  | ||||||
|     else: |  | ||||||
|         return docstring_mdx.strip() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None: |  | ||||||
|     """Print the parameters section from a docstring.""" |  | ||||||
|     matches = _RE_PARAMETERS.findall(docstring_mdx) |  | ||||||
|     if matches: |  | ||||||
|         header = "#" * header_level |  | ||||||
|         print(f"\n{header} Parameters") |  | ||||||
|         for match in matches: |  | ||||||
|             print(f"\n{match[0].strip()}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _print_returns_section( |  | ||||||
|     docstring_mdx: str, *, context_name: str, header_level: int |  | ||||||
| ) -> None: |  | ||||||
|     """Print the returns section from a docstring.""" |  | ||||||
|     return_matches = _RE_RETURNS.findall(docstring_mdx) |  | ||||||
|     returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx) |  | ||||||
|  |  | ||||||
|     if return_matches or returntype_matches: |  | ||||||
|         header = "#" * header_level |  | ||||||
|         print(f"\n{header} Returns") |  | ||||||
|  |  | ||||||
|         if returntype_matches: |  | ||||||
|             if len(returntype_matches) > 1: |  | ||||||
|                 raise ValueError( |  | ||||||
|                     f"More than one <returntype> tag found in docstring for {context_name}" |  | ||||||
|                 ) |  | ||||||
|             print(f"\n**Type**: {returntype_matches[0][0].strip()}") |  | ||||||
|  |  | ||||||
|         if return_matches: |  | ||||||
|             for match in return_matches: |  | ||||||
|                 print(f"\n{match[0].strip()}") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_docstring(obj, use_dict_check: bool = False) -> str: |  | ||||||
|     """Get docstring from an object, with fallback to default message.""" |  | ||||||
|     # Check whether the class/method itself has docs and not just |  | ||||||
|     # the superclass. |  | ||||||
|     if use_dict_check: |  | ||||||
|         has_doc = obj.__dict__.get("__doc__", None) is not None |  | ||||||
|     else: |  | ||||||
|         has_doc = getattr(obj, "__doc__", None) is not None |  | ||||||
|  |  | ||||||
|     # We use inspect.getdoc because it does normalization. |  | ||||||
|     doc = inspect.getdoc(obj) |  | ||||||
|  |  | ||||||
|     return doc if has_doc and doc is not None else "No documentation available." |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _process_and_print_docstring( |  | ||||||
|     docstring: str, *, kernel_name: str, context_name: str, header_level: int |  | ||||||
| ) -> None: |  | ||||||
|     """Convert docstring to MDX and print description, parameters, and returns sections.""" |  | ||||||
|     docstring_mdx = convert_rst_docstring_to_mdx( |  | ||||||
|         docstring, page_info={"package_name": kernel_name} |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     # Print the description |  | ||||||
|     description = _extract_description_before_tags(docstring_mdx) |  | ||||||
|     print(f"\n{description}") |  | ||||||
|  |  | ||||||
|     # Print parameters and returns sections |  | ||||||
|     _print_parameters_section(docstring_mdx, header_level=header_level) |  | ||||||
|     _print_returns_section( |  | ||||||
|         docstring_mdx, context_name=context_name, header_level=header_level |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None: |  | ||||||
|     kernel_module = get_kernel(repo_id=repo_id, revision=revision) |  | ||||||
|     kernel_name = repo_id.split("/")[-1].replace("-", "_") |  | ||||||
|  |  | ||||||
|     generate_metadata(kernel_module) |  | ||||||
|     generate_kernel_doc(kernel_module, kernel_name) |  | ||||||
|     generate_function_doc(kernel_module, kernel_name) |  | ||||||
|     generate_layers_doc(kernel_module, kernel_name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_metadata(module: ModuleType) -> None: |  | ||||||
|     metadata = getattr(module, "__kernel_metadata__", {}) |  | ||||||
|     if "tags" not in metadata: |  | ||||||
|         metadata["tags"] = ["kernel"] |  | ||||||
|     else: |  | ||||||
|         if "kernel" not in metadata["tags"]: |  | ||||||
|             metadata["tags"].append("kernel") |  | ||||||
|  |  | ||||||
|     print("---") |  | ||||||
|     print(yaml.dump(metadata), end="") |  | ||||||
|     print("---") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None: |  | ||||||
|     docstring = module.__doc__.strip() if module.__doc__ is not None else None |  | ||||||
|     if docstring: |  | ||||||
|         title, rest = docstring.split("\n", 1) |  | ||||||
|         print(f"# {title.strip()}") |  | ||||||
|         print( |  | ||||||
|             f"\n{convert_rst_docstring_to_mdx(rest.strip(), page_info={'package_name': kernel_name})}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None: |  | ||||||
|     print("\n## Functions") |  | ||||||
|  |  | ||||||
|     # Track if we found any functions |  | ||||||
|     found_functions = False |  | ||||||
|  |  | ||||||
|     for name, func in inspect.getmembers(kernel_module, inspect.isfunction): |  | ||||||
|         # Do not include imported functions. |  | ||||||
|         if func.__module__ != kernel_module.__name__: |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         # Exclude private functions. |  | ||||||
|         if name.startswith("_"): |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         found_functions = True |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             sig = inspect.signature(func) |  | ||||||
|             docstring = _get_docstring(func) |  | ||||||
|         except ValueError: |  | ||||||
|             print( |  | ||||||
|                 f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}", |  | ||||||
|                 file=sys.stderr, |  | ||||||
|             ) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         print(f"\n### Function `{name}`") |  | ||||||
|         print(f"\n`{sig}`") |  | ||||||
|  |  | ||||||
|         _process_and_print_docstring( |  | ||||||
|             docstring, kernel_name=kernel_name, context_name=name, header_level=3 |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if not found_functions: |  | ||||||
|         print("\nNo public top-level functions.") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None: |  | ||||||
|     # Check if layers module is available |  | ||||||
|     layers_module = getattr(kernel_module, "layers", None) |  | ||||||
|     if layers_module is None: |  | ||||||
|         return |  | ||||||
|  |  | ||||||
|     print("\n## Layers") |  | ||||||
|  |  | ||||||
|     # Track if we found any classes |  | ||||||
|     found_classes = False |  | ||||||
|  |  | ||||||
|     for class_name, cls in inspect.getmembers(layers_module, inspect.isclass): |  | ||||||
|         # Exclude classes that were imported. |  | ||||||
|         if cls.__module__ != layers_module.__name__: |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         found_classes = True |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             # Get docstring, but not from superclasses. |  | ||||||
|             class_docstring = _get_docstring(cls, use_dict_check=True) |  | ||||||
|         except Exception: |  | ||||||
|             print( |  | ||||||
|                 f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}", |  | ||||||
|                 file=sys.stderr, |  | ||||||
|             ) |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         print(f"\n### Class `{class_name}`") |  | ||||||
|  |  | ||||||
|         # Always print class description (helper handles conversion and formatting) |  | ||||||
|         class_docstring_mdx = convert_rst_docstring_to_mdx( |  | ||||||
|             class_docstring, page_info={"package_name": kernel_name} |  | ||||||
|         ) |  | ||||||
|         description = _extract_description_before_tags(class_docstring_mdx) |  | ||||||
|         print(f"\n{description}") |  | ||||||
|  |  | ||||||
|         # Document methods |  | ||||||
|         print("\n#### Methods") |  | ||||||
|  |  | ||||||
|         for method_name, method in inspect.getmembers(cls, inspect.isfunction): |  | ||||||
|             # Note: also skip __init__, since extension layers cannot have a constructor. |  | ||||||
|             if method_name.startswith("_"): |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             # Skip methods from superclasses. |  | ||||||
|             if method_name not in cls.__dict__: |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             try: |  | ||||||
|                 sig = inspect.signature(method) |  | ||||||
|                 method_docstring = _get_docstring(method) |  | ||||||
|             except ValueError: |  | ||||||
|                 print( |  | ||||||
|                     f"Warning: Could not retrieve signature for {method_name} in {class_name}", |  | ||||||
|                     file=sys.stderr, |  | ||||||
|                 ) |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             print(f"\n##### Method `{method_name}`") |  | ||||||
|             print(f"\n`{sig}`") |  | ||||||
|  |  | ||||||
|             _process_and_print_docstring( |  | ||||||
|                 method_docstring, |  | ||||||
|                 kernel_name=kernel_name, |  | ||||||
|                 context_name=method_name, |  | ||||||
|                 header_level=6, |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|     if not found_classes: |  | ||||||
|         print("\nNo layers defined.") |  | ||||||
| @ -1,67 +1,19 @@ | |||||||
| from __future__ import annotations |  | ||||||
|  |  | ||||||
| import inspect | import inspect | ||||||
| import os | import os | ||||||
| import warnings | import warnings | ||||||
| from contextvars import ContextVar | from contextvars import ContextVar | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from dataclasses import dataclass, field | from dataclasses import dataclass, field | ||||||
| from enum import Flag, auto | from typing import TYPE_CHECKING, Dict, Union | ||||||
| from types import MethodType |  | ||||||
| from typing import ( |  | ||||||
|     TYPE_CHECKING, |  | ||||||
|     Dict, |  | ||||||
|     Optional, |  | ||||||
|     Tuple, |  | ||||||
|     Type, |  | ||||||
|     Union, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| from .utils import get_kernel | from .utils import get_kernel | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     import torch |  | ||||||
|     from torch import nn |     from torch import nn | ||||||
|  |  | ||||||
|  |  | ||||||
| _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0"))) | _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0"))) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Mode(Flag): |  | ||||||
|     """ |  | ||||||
|     Kernelize mode |  | ||||||
|  |  | ||||||
|     The `Mode` flag is used by `kernelize` to select kernels for the given |  | ||||||
|     mode. Mappings can be registered for specific modes. |  | ||||||
|  |  | ||||||
|     * `INFERENCE`: The kernel is used for inference. |  | ||||||
|     * `TRAINING`: The kernel is used for training. |  | ||||||
|     * `TORCH_COMPILE`: The kernel is used with `torch.compile`. |  | ||||||
|     * `DEFAULT`: In a kernel mapping, this kernel is used when no other mode |  | ||||||
|        matches. |  | ||||||
|  |  | ||||||
|     Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE` |  | ||||||
|     should be used for layers that are used for inference *with* `torch.compile`. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     _NONE = 0 |  | ||||||
|     DEFAULT = auto() |  | ||||||
|     TRAINING = auto() |  | ||||||
|     INFERENCE = auto() |  | ||||||
|     TORCH_COMPILE = auto() |  | ||||||
|  |  | ||||||
|     def __or__(self, other: Mode) -> Mode: |  | ||||||
|         union = super().__or__(other) |  | ||||||
|  |  | ||||||
|         if Mode.INFERENCE in union and Mode.TRAINING in union: |  | ||||||
|             raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.") |  | ||||||
|  |  | ||||||
|         if Mode.DEFAULT in union and union != Mode.DEFAULT: |  | ||||||
|             raise ValueError("Mode.DEFAULT cannot be combined with other modes.") |  | ||||||
|  |  | ||||||
|         return union |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass(frozen=True) | @dataclass(frozen=True) | ||||||
| class Device: | class Device: | ||||||
|     type: str |     type: str | ||||||
| @ -101,28 +53,29 @@ class LayerRepository: | |||||||
|         return hash((self.layer_name, self.repo_id, self.revision)) |         return hash((self.layer_name, self.repo_id, self.revision)) | ||||||
|  |  | ||||||
|  |  | ||||||
| _CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {} | _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar( | ||||||
|  |     "_KERNEL_MAPPING", default={} | ||||||
|  |  | ||||||
| _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = ( |  | ||||||
|     ContextVar("_KERNEL_MAPPING", default={}) |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def use_kernel_mapping( | def use_kernel_mapping( | ||||||
|     mapping: Dict[ |     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]], | ||||||
|         str, |  | ||||||
|         Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]], |  | ||||||
|     ], |  | ||||||
|     *, |     *, | ||||||
|     inherit_mapping: bool = True, |     inherit_mapping: bool = True, | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     Context manager that sets a mapping for a duration of the context. |     Context manager that sets a kernel mapping for the duration of the context. | ||||||
|  |  | ||||||
|     When `inherit_mapping` is set to `True` the current mapping will be |     Args: | ||||||
|     extended by `mapping` inside the context. If it is `False`, only |         mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`): | ||||||
|     `mapping` is used inside the context. |             A mapping between layer names and their corresponding kernel repositories. | ||||||
|  |         inherit_mapping (`bool`, *optional*, defaults to `True`): | ||||||
|  |             The current mapping will be extended by `mapping` when set to `True`. | ||||||
|  |             When set to `False`, the current mapping will be replaced by `mapping` | ||||||
|  |             for the duration of the context. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ContextManager`: Context manager that sets up the mapping. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     class ContextManager: |     class ContextManager: | ||||||
| @ -141,137 +94,72 @@ def use_kernel_mapping( | |||||||
|  |  | ||||||
|  |  | ||||||
| def register_kernel_mapping( | def register_kernel_mapping( | ||||||
|     mapping: Dict[ |     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]], | ||||||
|         str, |  | ||||||
|         Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]], |  | ||||||
|     ], |  | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     Allows one to register a mapping between a layer name and the corresponding |     Register a mapping between a layer name the corresponding kernel to use, depending on the device. | ||||||
|     kernel(s) to use, depending on the device. This should be used in conjunction |     This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname. | ||||||
|     with `kernelize`. |  | ||||||
|  |  | ||||||
|     Example usage: |     Args: | ||||||
|  |         mapping (`Dict[str, Dict[Union[Device, str], LayerRepository]]`): | ||||||
|  |             A mapping between layer names and their corresponding kernel repositories. | ||||||
|  |  | ||||||
|     ```python |     Example: | ||||||
|     from kernels import LayerRepository, register_kernel_mapping |         ```python | ||||||
|  |         from kernels import LayerRepository, register_kernel_mapping | ||||||
|  |  | ||||||
|     kernel_layer_mapping = { |         kernel_layer_mapping = { | ||||||
|       "LlamaRMSNorm": { |         "LlamaRMSNorm": { | ||||||
|           "cuda": LayerRepository( |             "cuda": LayerRepository( | ||||||
|               repo_id="kernels-community/activation", |                 repo_id="kernels-community/activation", | ||||||
|               layer_name="RmsNorm", |                 layer_name="RmsNorm", | ||||||
|               revision="layers", |                 revision="layers", | ||||||
|           ), |             ), | ||||||
|       }, |         }, | ||||||
|     } |         } | ||||||
|     register_kernel_mapping(kernel_layer_mapping) |         register_kernel_mapping(kernel_layer_mapping) | ||||||
|     ``` |         ``` | ||||||
|     """ |     """ | ||||||
|     # Merge with existing mappings. |     # Merge with existing mappings. | ||||||
|     for new_kernel, new_device_repos in mapping.items(): |     for new_kernel, new_device_repos in mapping.items(): | ||||||
|         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {}) |         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {}) | ||||||
|         for new_device, new_repo in new_device_repos.items(): |         for new_device, new_repo in new_device_repos.items(): | ||||||
|             device = ( |             if isinstance(new_device, str): | ||||||
|                 Device(type=new_device) if isinstance(new_device, str) else new_device |                 device_repo[Device(type=new_device)] = new_repo | ||||||
|             ) |  | ||||||
|  |  | ||||||
|             if isinstance(new_repo, LayerRepository): |  | ||||||
|                 kernel_options = {Mode.DEFAULT: new_repo} |  | ||||||
|             else: |             else: | ||||||
|                 kernel_options = new_repo |                 device_repo[new_device] = new_repo | ||||||
|  |  | ||||||
|             device_repo[device] = kernel_options |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def replace_kernel_forward_from_hub( | def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True): | ||||||
|     cls, |  | ||||||
|     layer_name: str, |  | ||||||
| ): |  | ||||||
|     """ |     """ | ||||||
|     Decorator that prepares a layer class to use a kernel from the Hugging Face Hub. |     Replace the forward function of a layer using a layer from the kernel hub. | ||||||
|  |     This function monkeypatches a layer, replacing the `forward` method | ||||||
|     This decorator stores the layer name and original forward method, which will be used |     of the layer with that of a layer from the hub. The replacement is done | ||||||
|     by the kernelize function to replace the forward implementation with the appropriate |     when a layer matching `layer_name` and device type is registered through | ||||||
|     kernel from the hub. |     [`register_layer_mapping`]. The device type is inferred from the first | ||||||
|  |     argument to `forward`. | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
|         cls: The layer class to decorate |         cls (`nn.Module`): | ||||||
|         layer_name: The name of the layer to use for kernel lookup |             The layer class to replace the forward function of. | ||||||
|  |         layer_name (`str`): | ||||||
|  |             The name to assign to the layer. | ||||||
|  |         use_fallback (`bool`, *optional*, defaults to `True`): | ||||||
|  |             Whether to use the fallback forward function if no kernel mapping | ||||||
|  |             is found. If set to `False`, a `ValueError` will be raised if no kernel | ||||||
|  |             mapping is found. | ||||||
|     """ |     """ | ||||||
|     cls.kernel_layer_name = layer_name |  | ||||||
|  |  | ||||||
|  |     fallback_forward = cls.forward | ||||||
|  |  | ||||||
| def _select_repository( |     cached_layer: Dict[LayerRepository, nn.Module] = {} | ||||||
|     repositories: Dict[Mode, LayerRepository], |  | ||||||
|     *, |  | ||||||
|     mode: Mode, |  | ||||||
| ) -> Optional[Tuple[LayerRepository, Mode]]: |  | ||||||
|     if mode in repositories: |  | ||||||
|         return (repositories[mode], mode) |  | ||||||
|     elif Mode.DEFAULT in repositories: |  | ||||||
|         return (repositories[Mode.DEFAULT], Mode.DEFAULT) |  | ||||||
|     else: |  | ||||||
|         return None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def kernelize( |  | ||||||
|     model: "nn.Module", |  | ||||||
|     *, |  | ||||||
|     mode: Mode, |  | ||||||
|     device: Optional[Union[str, "torch.device"]] = None, |  | ||||||
|     use_fallback: bool = True, |  | ||||||
| ): |  | ||||||
|     """ |  | ||||||
|     Iterate over all modules in the model and replace the `forward` method of |  | ||||||
|     extensible layers for which kernels are registered using `register_kernel_mapping` |  | ||||||
|     or `use_kernel_mapping`. |  | ||||||
|  |  | ||||||
|     Args: |  | ||||||
|         model: The PyTorch model to kernelize |  | ||||||
|         mode: the mode that the kernel is going to be used in (e.g. |  | ||||||
|             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training |  | ||||||
|             and `torch.compile`). |  | ||||||
|         device: The device type to load kernels for. The device type will be inferred |  | ||||||
|             from the parameters of the model when not provided. |  | ||||||
|         use_fallback: Whether to use the original forward method of modules when no |  | ||||||
|             compatible kernel could be found. If set to `False`, an exception will |  | ||||||
|             be raised in such cases. |  | ||||||
|  |  | ||||||
|     Returns: |  | ||||||
|         The kernelized model |  | ||||||
|     """ |  | ||||||
|     import torch |  | ||||||
|  |  | ||||||
|     if mode == Mode.DEFAULT: |  | ||||||
|         raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.") |  | ||||||
|  |  | ||||||
|     # Type check ignored because this causes a false negative on Python < 3.11. |  | ||||||
|     # Looks similar to: https://github.com/python/mypy/issues/9642 |  | ||||||
|     # Remove once we start doing typing checks on >= 3.11. |  | ||||||
|     if Mode.INFERENCE not in mode and Mode.TRAINING not in mode:  # type: ignore[operator] |  | ||||||
|         raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.") |  | ||||||
|  |  | ||||||
|     if device is None: |  | ||||||
|         device_type = _find_device(model) |  | ||||||
|     elif isinstance(device, str): |  | ||||||
|         device_type = Device(type=torch.device(device).type) |  | ||||||
|     else: |  | ||||||
|         device_type = Device(device.type) |  | ||||||
|     assert isinstance(device_type, Device) |  | ||||||
|  |  | ||||||
|     for _, module in model.named_modules(): |  | ||||||
|         module_class = type(module) |  | ||||||
|         if not hasattr(module_class, "kernel_layer_name"): |  | ||||||
|             continue |  | ||||||
|         layer_name = module_class.kernel_layer_name |  | ||||||
|  |  | ||||||
|  |     def forward(self, x, *args, **kwargs): | ||||||
|         if _DISABLE_KERNEL_MAPPING: |         if _DISABLE_KERNEL_MAPPING: | ||||||
|             _replace_forward(module, module_class) |             return fallback_forward(self, x, *args, **kwargs) | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         kernel = _KERNEL_MAPPING.get().get(str(layer_name)) |  | ||||||
|  |  | ||||||
|  |         needs_backward = self.training | ||||||
|  |         kernel = _KERNEL_MAPPING.get().get(layer_name) | ||||||
|         if kernel is None: |         if kernel is None: | ||||||
|             warnings.warn( |             warnings.warn( | ||||||
|                 "\n" |                 "\n" | ||||||
| @ -281,70 +169,91 @@ def kernelize( | |||||||
|             ) |             ) | ||||||
|             if not use_fallback: |             if not use_fallback: | ||||||
|                 raise ValueError(f"No layer mapping for `{layer_name}`") |                 raise ValueError(f"No layer mapping for `{layer_name}`") | ||||||
|             _replace_forward(module, module_class) |             return fallback_forward(self, x, *args, **kwargs) | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         # Get kernel options for the device |         device = getattr(x, "device", None) | ||||||
|         repos = kernel.get(device_type) |         if device is None: | ||||||
|  |             return fallback_forward(self, x, *args, **kwargs) | ||||||
|  |  | ||||||
|         if repos is None: |         repo = kernel.get(Device(type=device.type)) | ||||||
|  |         if repo is None: | ||||||
|             if not use_fallback: |             if not use_fallback: | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     f"No layer mapping for `{layer_name}` with device type `{device_type}`" |                     f"No layer mapping for `{layer_name}` with device type `{device.type}`" | ||||||
|                 ) |                 ) | ||||||
|             _replace_forward(module, module_class) |             return fallback_forward(self, x, *args, **kwargs) | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         repo_with_mode = _select_repository( |         # Short-circuit if we already loaded the layer. | ||||||
|             repos, |         layer = cached_layer.get(repo, None) | ||||||
|             mode=mode, |         if layer is not None: | ||||||
|  |             if needs_backward and not getattr(layer, "has_backward", True): | ||||||
|  |                 return fallback_forward(self, x, *args, **kwargs) | ||||||
|  |             return layer.forward(self, x, *args, **kwargs) | ||||||
|  |  | ||||||
|  |         layer = _get_kernel_layer( | ||||||
|  |             repo_id=repo.repo_id, | ||||||
|  |             layer_name=repo.layer_name, | ||||||
|  |             revision=repo.revision, | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         if repo_with_mode is None: |         # We have to validate against the original signature. | ||||||
|             if not use_fallback: |         orig_forward = cls.forward | ||||||
|                 raise ValueError( |         try: | ||||||
|                     f"No repository for `{layer_name}` for configuration mode={mode}" |             cls.forward = fallback_forward | ||||||
|                 ) |             _validate_layer(check_cls=cls, cls=layer) | ||||||
|             _replace_forward(module, module_class) |         finally: | ||||||
|             continue |             cls.forward = orig_forward | ||||||
|  |  | ||||||
|         repo, repo_mode = repo_with_mode |         cached_layer[repo] = layer | ||||||
|  |  | ||||||
|         layer = _get_layer_memoize(repo, module_class) |         if needs_backward and not getattr(layer, "has_backward", True): | ||||||
|  |             return fallback_forward(self, x, *args, **kwargs) | ||||||
|  |         return layer.forward(self, x, *args, **kwargs) | ||||||
|  |  | ||||||
|         # Ideally we would do validation on the mapping where we check that |     cls.forward = forward | ||||||
|         # e.g. if a repo class is registered for TRAINING | TORCH_COMPILE, |  | ||||||
|         # the actual layer is compatible with that. Unfortunately, this would |  | ||||||
|         # mean that we have to pre-download everything. |  | ||||||
|         _validate_layer_has_mode( |  | ||||||
|             layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|         _conditionally_replace_forward( |  | ||||||
|             module=module, |  | ||||||
|             layer=layer, |  | ||||||
|             mode=mode, |  | ||||||
|             use_fallback=use_fallback, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     return model |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def use_kernel_forward_from_hub(layer_name: str): | def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True): | ||||||
|     """ |     """ | ||||||
|     Make a layer extensible using the name `layer_name`. |     Replace the forward function of a layer using a layer from the kernel hub. | ||||||
|  |  | ||||||
|  |     This decorator can be applied to a layer and replaces the forward method | ||||||
|  |     of the layer with that of a layer from the hub. The replacement is done | ||||||
|  |     when a layer matching `layer_name` and device type is registered through | ||||||
|  |     [`register_layer_mapping`]. The device type is inferred from the first | ||||||
|  |     argument to `forward`. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         layer_name (`str`): | ||||||
|  |             The name to assign to the layer. | ||||||
|  |         use_fallback (`bool`, *optional*, defaults to `True`): | ||||||
|  |             Whether to use the fallback forward function if no kernel mapping | ||||||
|  |             is found. If set to `False`, a `ValueError` will be raised if no kernel | ||||||
|  |             mapping is found. | ||||||
|  |  | ||||||
|  |     Example: | ||||||
|  |         ```python | ||||||
|  |         from kernels import use_kernel_forward_from_hub | ||||||
|  |  | ||||||
|  |         @use_kernel_forward_from_hub(layer_name="LlamaRMSNorm") | ||||||
|  |         class LlamaRMSNorm(nn.Module): | ||||||
|  |             def __init__(self, *args, **kwargs): | ||||||
|  |                 super().__init__(*args, **kwargs) | ||||||
|  |  | ||||||
|  |             def forward(self, x): | ||||||
|  |                 # Original forward implementation | ||||||
|  |                 pass | ||||||
|  |         ``` | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def decorator(cls): |     def decorator(cls): | ||||||
|         replace_kernel_forward_from_hub(cls, layer_name) |         replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback) | ||||||
|         return cls |         return cls | ||||||
|  |  | ||||||
|     return decorator |     return decorator | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_kernel_layer( | def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module": | ||||||
|     *, repo_id: str, layer_name: str, revision: str |  | ||||||
| ) -> Type["nn.Module"]: |  | ||||||
|     """Get a layer from a kernel.""" |     """Get a layer from a kernel.""" | ||||||
|  |  | ||||||
|     kernel = get_kernel(repo_id, revision=revision) |     kernel = get_kernel(repo_id, revision=revision) | ||||||
| @ -361,13 +270,13 @@ def _get_kernel_layer( | |||||||
|  |  | ||||||
|  |  | ||||||
| def _validate_layer(*, check_cls, cls): | def _validate_layer(*, check_cls, cls): | ||||||
|     import torch.nn as nn |  | ||||||
|  |  | ||||||
|     # The layer must have at least have the following properties: (1) it |     # The layer must have at least have the following properties: (1) it | ||||||
|     # must be stateless; (2) the forward signature should correspond to |     # must be stateless; (2) the forward signature should correspond to | ||||||
|     # the signature it is replacing; (3) forward should not call other |     # the signature it is replacing; (3) forward should not call other | ||||||
|     # methods. |     # methods. | ||||||
|  |  | ||||||
|  |     from torch import nn | ||||||
|  |  | ||||||
|     if not issubclass(cls, nn.Module): |     if not issubclass(cls, nn.Module): | ||||||
|         raise TypeError(f"Layer `{cls}` is not a Torch layer.") |         raise TypeError(f"Layer `{cls}` is not a Torch layer.") | ||||||
|  |  | ||||||
| @ -380,8 +289,7 @@ def _validate_layer(*, check_cls, cls): | |||||||
|     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)} |     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)} | ||||||
|     cls_members = {name for name, _ in inspect.getmembers(cls)} |     cls_members = {name for name, _ in inspect.getmembers(cls)} | ||||||
|     difference = cls_members - torch_module_members |     difference = cls_members - torch_module_members | ||||||
|     # verify if : difference ⊄ {"can_torch_compile", "has_backward"} |     if difference != set() and difference != {"has_backward"}: | ||||||
|     if not difference <= {"can_torch_compile", "has_backward"}: |  | ||||||
|         raise TypeError("Layer must not contain additional members.") |         raise TypeError("Layer must not contain additional members.") | ||||||
|  |  | ||||||
|     # Check whether the forward signatures are similar. |     # Check whether the forward signatures are similar. | ||||||
| @ -398,92 +306,3 @@ def _validate_layer(*, check_cls, cls): | |||||||
|             raise TypeError( |             raise TypeError( | ||||||
|                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" |                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _find_device(model: "nn.Module") -> Device: |  | ||||||
|     try: |  | ||||||
|         param = next(model.parameters()) |  | ||||||
|     except StopIteration: |  | ||||||
|         raise ValueError( |  | ||||||
|             "Cannot determine model device, provide as `device` argument to `kernelize`." |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     return Device(type=param.device.type) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _conditionally_replace_forward( |  | ||||||
|     *, |  | ||||||
|     module: "nn.Module", |  | ||||||
|     layer: Type["nn.Module"], |  | ||||||
|     mode: Mode, |  | ||||||
|     use_fallback: bool, |  | ||||||
| ): |  | ||||||
|     module_class = type(module) |  | ||||||
|  |  | ||||||
|     # Switch to fallback if the mode is not supported by the layer. |  | ||||||
|     # Note that this is useful even after _validate_layer_has_mode because |  | ||||||
|     # layers registered with the DEFAULT mode never get rejected by |  | ||||||
|     # _validate_layer_has_mode. For such layers, we want to fall back in |  | ||||||
|     # case the layer does not support the given mode. |  | ||||||
|     needs_fallback = Mode.TORCH_COMPILE in mode and not getattr( |  | ||||||
|         layer, "can_torch_compile", False |  | ||||||
|     ) |  | ||||||
|     needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True) |  | ||||||
|  |  | ||||||
|     if needs_fallback: |  | ||||||
|         if use_fallback: |  | ||||||
|             _replace_forward(module, module_class) |  | ||||||
|         else: |  | ||||||
|             raise ValueError(f"Available kernel does not support mode: {mode}") |  | ||||||
|     else: |  | ||||||
|         _replace_forward(module, layer) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]): |  | ||||||
|     module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _validate_layer_has_mode( |  | ||||||
|     *, |  | ||||||
|     layer_name: str, |  | ||||||
|     module: Type["nn.Module"], |  | ||||||
|     repo: LayerRepository, |  | ||||||
|     repo_mode: Mode, |  | ||||||
| ): |  | ||||||
|     """ |  | ||||||
|     Check that a repository supports the mode that it was registered for. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True): |  | ||||||
|         raise ValueError( |  | ||||||
|             f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n" |  | ||||||
|             f"Was registered for `{layer_name}` with mode `{repo_mode}`" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if Mode.TORCH_COMPILE in repo_mode and not getattr( |  | ||||||
|         module, "can_torch_compile", False |  | ||||||
|     ): |  | ||||||
|         raise ValueError( |  | ||||||
|             f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n" |  | ||||||
|             f"Was registered for `{layer_name}` with mode `{repo_mode}`" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     return True |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_layer_memoize( |  | ||||||
|     repo: LayerRepository, module_class: Type["nn.Module"] |  | ||||||
| ) -> Type["nn.Module"]: |  | ||||||
|     layer = _CACHED_LAYER.get(repo, None) |  | ||||||
|     if layer is not None: |  | ||||||
|         return layer |  | ||||||
|  |  | ||||||
|     layer = _get_kernel_layer( |  | ||||||
|         repo_id=repo.repo_id, |  | ||||||
|         layer_name=repo.layer_name, |  | ||||||
|         revision=repo.revision, |  | ||||||
|     ) |  | ||||||
|     _validate_layer(check_cls=module_class, cls=layer) |  | ||||||
|     _CACHED_LAYER[repo] = layer |  | ||||||
|  |  | ||||||
|     return layer |  | ||||||
|  | |||||||
| @ -43,23 +43,14 @@ def build_variant() -> str: | |||||||
|     elif torch.version.hip is not None: |     elif torch.version.hip is not None: | ||||||
|         rocm_version = parse(torch.version.hip.split("-")[0]) |         rocm_version = parse(torch.version.hip.split("-")[0]) | ||||||
|         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" |         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" | ||||||
|     elif torch.backends.mps.is_available(): |  | ||||||
|         compute_framework = "metal" |  | ||||||
|     else: |     else: | ||||||
|         raise AssertionError( |         raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.") | ||||||
|             "Torch was not compiled with CUDA, Metal, or ROCm enabled." |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     torch_version = parse(torch.__version__) |     torch_version = parse(torch.__version__) | ||||||
|  |     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||||
|     cpu = platform.machine() |     cpu = platform.machine() | ||||||
|     os = platform.system().lower() |     os = platform.system().lower() | ||||||
|  |  | ||||||
|     if os == "darwin": |  | ||||||
|         cpu = "aarch64" if cpu == "arm64" else cpu |  | ||||||
|         return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" |  | ||||||
|  |  | ||||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" |  | ||||||
|  |  | ||||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}" |     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}" | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -110,23 +101,6 @@ def install_kernel( | |||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     try: |  | ||||||
|         return _load_kernel_from_path(repo_path, package_name, variant_locks) |  | ||||||
|     except FileNotFoundError: |  | ||||||
|         # Redo with more specific error message. |  | ||||||
|         raise FileNotFoundError( |  | ||||||
|             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def _load_kernel_from_path( |  | ||||||
|     repo_path: Path, |  | ||||||
|     package_name: str, |  | ||||||
|     variant_locks: Optional[Dict[str, VariantLock]] = None, |  | ||||||
| ) -> Tuple[str, Path]: |  | ||||||
|     variant = build_variant() |  | ||||||
|     universal_variant = universal_build_variant() |  | ||||||
|  |  | ||||||
|     variant_path = repo_path / "build" / variant |     variant_path = repo_path / "build" / variant | ||||||
|     universal_variant_path = repo_path / "build" / universal_variant |     universal_variant_path = repo_path / "build" / universal_variant | ||||||
|  |  | ||||||
| @ -145,7 +119,7 @@ def _load_kernel_from_path( | |||||||
|  |  | ||||||
|     if not os.path.exists(module_init_path): |     if not os.path.exists(module_init_path): | ||||||
|         raise FileNotFoundError( |         raise FileNotFoundError( | ||||||
|             f"Kernel at path `{repo_path}` does not have build: {variant}" |             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     return package_name, variant_path |     return package_name, variant_path | ||||||
| @ -184,27 +158,47 @@ def install_kernel_all_variants( | |||||||
|  |  | ||||||
| def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | ||||||
|     """ |     """ | ||||||
|     Download and import a kernel from the Hugging Face Hub. |     Load a kernel from the kernel hub. | ||||||
|  |  | ||||||
|     The kernel is downloaded from the repository `repo_id` at |     This function downloads a kernel to the local Hugging Face Hub cache | ||||||
|     branch/commit/tag `revision`. |     directory (if it was not downloaded before) and then loads the kernel. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): The Hub repository containing the kernel. | ||||||
|  |         revision (`str`, *optional*, defaults to `"main"`): The specific | ||||||
|  |             revision (branch, tag, or commit) to download. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module. | ||||||
|  |  | ||||||
|  |     Example: | ||||||
|  |         ```python | ||||||
|  |         from kernels import get_kernel | ||||||
|  |         kernel = get_kernel("username/my-kernel") | ||||||
|  |         result = kernel.kernel_function(input_data) | ||||||
|  |         ``` | ||||||
|     """ |     """ | ||||||
|     package_name, package_path = install_kernel(repo_id, revision=revision) |     package_name, package_path = install_kernel(repo_id, revision=revision) | ||||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") |     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: |  | ||||||
|     """ |  | ||||||
|     Import a kernel from a local kernel repository path. |  | ||||||
|     """ |  | ||||||
|     package_name, package_path = _load_kernel_from_path(repo_path, package_name) |  | ||||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def has_kernel(repo_id: str, revision: str = "main") -> bool: | def has_kernel(repo_id: str, revision: str = "main") -> bool: | ||||||
|     """ |     """ | ||||||
|     Check whether a kernel build exists for the current environment |     Check whether a kernel build exists for the current environment | ||||||
|     (Torch version and compute framework). |  | ||||||
|  |     This function checks whether there exists a kernel build for the current | ||||||
|  |     environment (Torch version, compute framework and architecture). | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         revision (`str`, *optional*, defaults to `"main"`): | ||||||
|  |             The kernel revision. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `bool`: | ||||||
|  |             `True` if a compatible kernel build exists for the current environment, | ||||||
|  |             `False` otherwise. | ||||||
|     """ |     """ | ||||||
|     package_name = package_name_from_repo_id(repo_id) |     package_name = package_name_from_repo_id(repo_id) | ||||||
|     variant = build_variant() |     variant = build_variant() | ||||||
| @ -226,10 +220,25 @@ def has_kernel(repo_id: str, revision: str = "main") -> bool: | |||||||
|  |  | ||||||
| def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||||
|     """ |     """ | ||||||
|     Get a pre-downloaded, locked kernel. |     Loads a pre-downloaded, locked kernel module from the local cache. | ||||||
|  |  | ||||||
|     If `lockfile` is not specified, the lockfile will be loaded from the |     This function retrieves a kernel that was locked at a specific revision with | ||||||
|     caller's package metadata. |     `kernels lock <project>` and then downloaded with `kernels download <project>`. | ||||||
|  |  | ||||||
|  |     This function will fail if the kernel was not locked or downloaded. If you want | ||||||
|  |     the kernel to be downloaded when it is not in the cache, use [`get_locked_kernel`] | ||||||
|  |     instead. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         lockfile (`Optional[Path]`, *optional*, defaults to `None`): | ||||||
|  |             Path to a lockfile containing the commit SHA for the kernel. If `None`, | ||||||
|  |             the lock information is automatically retrieved from the metadata of the | ||||||
|  |             calling package. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module corresponding to the locked version. | ||||||
|     """ |     """ | ||||||
|     if lockfile is None: |     if lockfile is None: | ||||||
|         locked_sha = _get_caller_locked_kernel(repo_id) |         locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
| @ -274,7 +283,27 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | ||||||
|     """Get a kernel using a lock file.""" |     """ | ||||||
|  |     Loads a locked kernel module. | ||||||
|  |  | ||||||
|  |     This function retrieves a kernel that was locked at a specific revision with | ||||||
|  |     `kernels lock <project>`. | ||||||
|  |  | ||||||
|  |     This function will download the locked kernel when it is not available in the | ||||||
|  |     cache. If you want loading to fail if the kernel is not in the cache, use | ||||||
|  |     [`load_kernel`] instead. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         repo_id (`str`): | ||||||
|  |             The Hub repository containing the kernel. | ||||||
|  |         lockfile (`Optional[Path]`, *optional*, defaults to `None`): | ||||||
|  |             Path to a lockfile containing the commit SHA for the kernel. If `None`, | ||||||
|  |             the lock information is automatically retrieved from the metadata of the | ||||||
|  |             calling package. | ||||||
|  |  | ||||||
|  |     Returns: | ||||||
|  |         `ModuleType`: The imported kernel module corresponding to the locked version. | ||||||
|  |     """ | ||||||
|     locked_sha = _get_caller_locked_kernel(repo_id) |     locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
|  |  | ||||||
|     if locked_sha is None: |     if locked_sha is None: | ||||||
|  | |||||||
| @ -1,186 +0,0 @@ | |||||||
| import email.policy |  | ||||||
| import os |  | ||||||
| from dataclasses import dataclass |  | ||||||
| from email.message import Message |  | ||||||
| from importlib.metadata import PackageNotFoundError, version |  | ||||||
| from pathlib import Path |  | ||||||
| from typing import Optional |  | ||||||
|  |  | ||||||
| try: |  | ||||||
|     KERNELS_VERSION = version("kernels") |  | ||||||
| except PackageNotFoundError: |  | ||||||
|     KERNELS_VERSION = "unknown" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass |  | ||||||
| class Metadata: |  | ||||||
|     name: str |  | ||||||
|     version: str |  | ||||||
|     cuda_version: Optional[str] |  | ||||||
|     cxx_abi_version: Optional[str] |  | ||||||
|     torch_version: Optional[str] |  | ||||||
|     os: Optional[str] |  | ||||||
|     platform: Optional[str] |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def is_universal(self) -> bool: |  | ||||||
|         return self.platform is None |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_variant_to_wheel( |  | ||||||
|     repo_id: str, |  | ||||||
|     *, |  | ||||||
|     version: str, |  | ||||||
|     variant_path: Path, |  | ||||||
|     wheel_dir: Path, |  | ||||||
|     manylinux_version: str = "2.28", |  | ||||||
|     python_version: str = "3.9", |  | ||||||
| ) -> Path: |  | ||||||
|     """ |  | ||||||
|     Create a wheel file from the variant path. |  | ||||||
|     """ |  | ||||||
|     name = repo_id.split("/")[-1].replace("_", "-") |  | ||||||
|     metadata = extract_metadata(name, version, variant_path) |  | ||||||
|     return build_wheel( |  | ||||||
|         metadata, |  | ||||||
|         variant_path=variant_path, |  | ||||||
|         wheel_dir=wheel_dir, |  | ||||||
|         manylinux_version=manylinux_version, |  | ||||||
|         python_version=python_version, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata: |  | ||||||
|     """ |  | ||||||
|     Extract metadata from the variant path. |  | ||||||
|     """ |  | ||||||
|     if variant_path.name == "torch-universal": |  | ||||||
|         return Metadata( |  | ||||||
|             name=name, |  | ||||||
|             version=version, |  | ||||||
|             cuda_version=None, |  | ||||||
|             cxx_abi_version=None, |  | ||||||
|             torch_version=None, |  | ||||||
|             os=None, |  | ||||||
|             platform=None, |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     if not variant_path.name.startswith("torch"): |  | ||||||
|         raise ValueError("Currently only conversion of Torch kernels is supported.") |  | ||||||
|  |  | ||||||
|     variant_parts = variant_path.name.removeprefix("torch").split("-") |  | ||||||
|     if len(variant_parts) != 5: |  | ||||||
|         raise ValueError(f"Invalid variant name: {variant_path.name}") |  | ||||||
|  |  | ||||||
|     torch_version = f"{variant_parts[0][:-1]}.{variant_parts[0][-1:]}" |  | ||||||
|     cpp_abi_version = variant_parts[1].removeprefix("cxx") |  | ||||||
|     cuda_version = variant_parts[2].removeprefix("cu") |  | ||||||
|     platform = variant_parts[3].replace("-", "_") |  | ||||||
|     os = variant_parts[4] |  | ||||||
|  |  | ||||||
|     return Metadata( |  | ||||||
|         name=name, |  | ||||||
|         version=version, |  | ||||||
|         cuda_version=cuda_version, |  | ||||||
|         cxx_abi_version=cpp_abi_version, |  | ||||||
|         torch_version=torch_version, |  | ||||||
|         os=os, |  | ||||||
|         platform=platform, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_wheel( |  | ||||||
|     metadata: Metadata, |  | ||||||
|     *, |  | ||||||
|     variant_path: Path, |  | ||||||
|     wheel_dir: Path, |  | ||||||
|     manylinux_version: str = "2.28", |  | ||||||
|     python_version: str = "3.9", |  | ||||||
| ) -> Path: |  | ||||||
|     """ |  | ||||||
|     Build the wheel file. |  | ||||||
|     """ |  | ||||||
|     try: |  | ||||||
|         from wheel.wheelfile import WheelFile  # type: ignore |  | ||||||
|     except ImportError: |  | ||||||
|         raise ImportError( |  | ||||||
|             "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`" |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     name = metadata.name.replace("-", "_") |  | ||||||
|     python_version_flat = python_version.replace(".", "") |  | ||||||
|  |  | ||||||
|     if metadata.is_universal: |  | ||||||
|         python_tag = f"py{python_version_flat}" |  | ||||||
|         abi_tag = "none" |  | ||||||
|         platform_tag = "any" |  | ||||||
|         wheel_filename = ( |  | ||||||
|             f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl" |  | ||||||
|         ) |  | ||||||
|         dist_info_dir_name = f"{name}-{metadata.version}.dist-info" |  | ||||||
|         root_is_purelib = "true" |  | ||||||
|         requires_dist_torch = "torch" |  | ||||||
|     else: |  | ||||||
|         python_tag = f"cp{python_version_flat}" |  | ||||||
|         abi_tag = "abi3" |  | ||||||
|  |  | ||||||
|         if ( |  | ||||||
|             metadata.torch_version is None |  | ||||||
|             or metadata.cuda_version is None |  | ||||||
|             or metadata.cxx_abi_version is None |  | ||||||
|             or metadata.os is None |  | ||||||
|             or metadata.platform is None |  | ||||||
|         ): |  | ||||||
|             raise ValueError( |  | ||||||
|                 "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels." |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}" |  | ||||||
|  |  | ||||||
|         if metadata.os == "linux": |  | ||||||
|             platform_tag = ( |  | ||||||
|                 f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}" |  | ||||||
|             ) |  | ||||||
|         else: |  | ||||||
|             platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}" |  | ||||||
|  |  | ||||||
|         wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" |  | ||||||
|         dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info" |  | ||||||
|         root_is_purelib = "false" |  | ||||||
|         requires_dist_torch = f"torch=={metadata.torch_version}.*" |  | ||||||
|  |  | ||||||
|     wheel_path = wheel_dir / wheel_filename |  | ||||||
|  |  | ||||||
|     wheel_msg = Message(email.policy.compat32) |  | ||||||
|     wheel_msg.add_header("Wheel-Version", "1.0") |  | ||||||
|     wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})") |  | ||||||
|     wheel_msg.add_header("Root-Is-Purelib", root_is_purelib) |  | ||||||
|     wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}") |  | ||||||
|  |  | ||||||
|     metadata_msg = Message(email.policy.compat32) |  | ||||||
|     metadata_msg.add_header("Metadata-Version", "2.1") |  | ||||||
|     metadata_msg.add_header("Name", name) |  | ||||||
|     metadata_msg.add_header("Version", metadata.version) |  | ||||||
|     metadata_msg.add_header("Summary", f"{name} kernel") |  | ||||||
|     metadata_msg.add_header("Requires-Python", ">=3.9") |  | ||||||
|     metadata_msg.add_header("Requires-Dist", requires_dist_torch) |  | ||||||
|  |  | ||||||
|     source_pkg_dir = variant_path / name |  | ||||||
|  |  | ||||||
|     with WheelFile(wheel_path, "w") as wheel_file: |  | ||||||
|         for root, dirnames, filenames in os.walk(source_pkg_dir): |  | ||||||
|             for filename in filenames: |  | ||||||
|                 if filename.endswith(".pyc"): |  | ||||||
|                     continue |  | ||||||
|  |  | ||||||
|                 abs_filepath = os.path.join(root, filename) |  | ||||||
|                 entry_name = os.path.relpath(abs_filepath, variant_path) |  | ||||||
|                 wheel_file.write(abs_filepath, entry_name) |  | ||||||
|  |  | ||||||
|         wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL") |  | ||||||
|         wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8")) |  | ||||||
|  |  | ||||||
|         metadata_path = os.path.join(dist_info_dir_name, "METADATA") |  | ||||||
|         wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8")) |  | ||||||
|  |  | ||||||
|     return wheel_path |  | ||||||
| @ -1,10 +0,0 @@ | |||||||
| import sys |  | ||||||
|  |  | ||||||
| import pytest |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def pytest_runtest_setup(item): |  | ||||||
|     if "linux_only" in item.keywords and not sys.platform.startswith("linux"): |  | ||||||
|         pytest.skip("skipping Linux-only test on non-Linux platform") |  | ||||||
|     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"): |  | ||||||
|         pytest.skip("skipping macOS-only test on non-macOS platform") |  | ||||||
| @ -1,82 +1,54 @@ | |||||||
| [ | [ | ||||||
|   { |   { | ||||||
|     "repo_id": "kernels-community/activation", |     "repo_id": "kernels-community/activation", | ||||||
|     "sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538", |     "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0", | ||||||
|     "variants": { |     "variants": { | ||||||
|       "torch25-cxx11-cu118-x86_64-linux": { |       "torch25-cxx11-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b", |         "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx11-cu121-x86_64-linux": { |       "torch25-cxx11-cu121-x86_64-linux": { | ||||||
|         "hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa", |         "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx11-cu124-x86_64-linux": { |       "torch25-cxx11-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965", |         "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu118-x86_64-linux": { |       "torch25-cxx98-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c", |         "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu121-x86_64-linux": { |       "torch25-cxx98-cu121-x86_64-linux": { | ||||||
|         "hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd", |         "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch25-cxx98-cu124-x86_64-linux": { |       "torch25-cxx98-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e", |         "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu118-x86_64-linux": { |       "torch26-cxx11-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f", |         "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu124-x86_64-linux": { |       "torch26-cxx11-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47", |         "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005", | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch26-cxx11-cu126-aarch64-linux": { |  | ||||||
|         "hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd", |  | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx11-cu126-x86_64-linux": { |       "torch26-cxx11-cu126-x86_64-linux": { | ||||||
|         "hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c", |         "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu118-x86_64-linux": { |       "torch26-cxx98-cu118-x86_64-linux": { | ||||||
|         "hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204", |         "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f", | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu124-x86_64-linux": { |       "torch26-cxx98-cu124-x86_64-linux": { | ||||||
|         "hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1", |         "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf", | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch26-cxx98-cu126-aarch64-linux": { |  | ||||||
|         "hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab", |  | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       }, |       }, | ||||||
|       "torch26-cxx98-cu126-x86_64-linux": { |       "torch26-cxx98-cu126-x86_64-linux": { | ||||||
|         "hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2", |         "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a", | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch27-cxx11-cu118-x86_64-linux": { |  | ||||||
|         "hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78", |  | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch27-cxx11-cu126-aarch64-linux": { |  | ||||||
|         "hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e", |  | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch27-cxx11-cu126-x86_64-linux": { |  | ||||||
|         "hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405", |  | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch27-cxx11-cu128-aarch64-linux": { |  | ||||||
|         "hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99", |  | ||||||
|         "hash_type": "git_lfs_concat" |  | ||||||
|       }, |  | ||||||
|       "torch27-cxx11-cu128-x86_64-linux": { |  | ||||||
|         "hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597", |  | ||||||
|         "hash_type": "git_lfs_concat" |         "hash_type": "git_lfs_concat" | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -1,7 +1,7 @@ | |||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel | from kernels import get_kernel, has_kernel | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @ -9,19 +9,6 @@ def kernel(): | |||||||
|     return get_kernel("kernels-community/activation") |     return get_kernel("kernels-community/activation") | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture |  | ||||||
| def local_kernel(): |  | ||||||
|     package_name, path = install_kernel("kernels-community/activation", "main") |  | ||||||
|     # Path is the build variant path (build/torch-<...>), so the grandparent |  | ||||||
|     # is the kernel repository path. |  | ||||||
|     return get_local_kernel(path.parent.parent, package_name) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture |  | ||||||
| def metal_kernel(): |  | ||||||
|     return get_kernel("kernels-test/relu-metal") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def universal_kernel(): | def universal_kernel(): | ||||||
|     return get_kernel("kernels-community/triton-scaled-mm") |     return get_kernel("kernels-community/triton-scaled-mm") | ||||||
| @ -34,7 +21,6 @@ def device(): | |||||||
|     return "cuda" |     return "cuda" | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_gelu_fast(kernel, device): | def test_gelu_fast(kernel, device): | ||||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) |     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
| @ -50,31 +36,6 @@ def test_gelu_fast(kernel, device): | |||||||
|     assert torch.allclose(y, expected) |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_local_kernel(local_kernel, device): |  | ||||||
|     x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3) |  | ||||||
|     y = torch.empty_like(x) |  | ||||||
|  |  | ||||||
|     local_kernel.gelu_fast(y, x) |  | ||||||
|  |  | ||||||
|     expected = torch.tensor( |  | ||||||
|         [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], |  | ||||||
|         device=device, |  | ||||||
|         dtype=torch.float16, |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     assert torch.allclose(y, expected) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.darwin_only |  | ||||||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) |  | ||||||
| def test_relu_metal(metal_kernel, dtype): |  | ||||||
|     x = torch.arange(-10, 10, dtype=dtype, device="mps") |  | ||||||
|     y = metal_kernel.relu(x) |  | ||||||
|     assert torch.allclose(y, torch.relu(x)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| @pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||||
|     "kernel_exists", |     "kernel_exists", | ||||||
|     [ |     [ | ||||||
| @ -91,7 +52,6 @@ def test_has_kernel(kernel_exists): | |||||||
|     assert has_kernel(repo_id, revision=revision) == kernel |     assert has_kernel(repo_id, revision=revision) == kernel | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_universal_kernel(universal_kernel): | def test_universal_kernel(universal_kernel): | ||||||
|     torch.manual_seed(0) |     torch.manual_seed(0) | ||||||
|     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") |     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") | ||||||
|  | |||||||
| @ -16,21 +16,18 @@ def device(): | |||||||
|     return "cuda" |     return "cuda" | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_gelu_small(kernel, device, benchmark): | def test_gelu_small(kernel, device, benchmark): | ||||||
|     x = torch.randn(32, 32, dtype=torch.float16, device=device) |     x = torch.randn(32, 32, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|     benchmark(kernel.gelu_fast, y, x) |     benchmark(kernel.gelu_fast, y, x) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_gelu_medium(kernel, device, benchmark): | def test_gelu_medium(kernel, device, benchmark): | ||||||
|     x = torch.randn(128, 128, dtype=torch.float16, device=device) |     x = torch.randn(128, 128, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|     benchmark(kernel.gelu_fast, y, x) |     benchmark(kernel.gelu_fast, y, x) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_gelu_large(kernel, device, benchmark): | def test_gelu_large(kernel, device, benchmark): | ||||||
|     x = torch.randn(512, 512, dtype=torch.float16, device=device) |     x = torch.randn(512, 512, dtype=torch.float16, device=device) | ||||||
|     y = torch.empty_like(x) |     y = torch.empty_like(x) | ||||||
|  | |||||||
| @ -1,8 +1,6 @@ | |||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| import pytest |  | ||||||
|  |  | ||||||
| from kernels import load_kernel | from kernels import load_kernel | ||||||
| from kernels.cli import download_kernels | from kernels.cli import download_kernels | ||||||
|  |  | ||||||
| @ -19,7 +17,6 @@ def test_download_all_hash_validation(): | |||||||
|     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) |     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_load_locked(): | def test_load_locked(): | ||||||
|     project_dir = Path(__file__).parent / "kernel_locking" |     project_dir = Path(__file__).parent / "kernel_locking" | ||||||
|     # Also validates that hashing works correctly. |     # Also validates that hashing works correctly. | ||||||
|  | |||||||
| @ -1,5 +1,3 @@ | |||||||
| from contextlib import nullcontext |  | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
| @ -8,8 +6,6 @@ from torch.nn import functional as F | |||||||
| from kernels import ( | from kernels import ( | ||||||
|     Device, |     Device, | ||||||
|     LayerRepository, |     LayerRepository, | ||||||
|     Mode, |  | ||||||
|     kernelize, |  | ||||||
|     register_kernel_mapping, |     register_kernel_mapping, | ||||||
|     use_kernel_forward_from_hub, |     use_kernel_forward_from_hub, | ||||||
| ) | ) | ||||||
| @ -20,18 +16,14 @@ kernel_layer_mapping = { | |||||||
|         Device(type="cuda"): LayerRepository( |         Device(type="cuda"): LayerRepository( | ||||||
|             repo_id="kernels-community/activation", |             repo_id="kernels-community/activation", | ||||||
|             layer_name="SiluAndMul", |             layer_name="SiluAndMul", | ||||||
|         ) |             revision="layers", | ||||||
|     }, |  | ||||||
|     "SiluAndMulNoCompile": { |  | ||||||
|         "cuda": LayerRepository( |  | ||||||
|             repo_id="kernels-test/op-without-fake-test", |  | ||||||
|             layer_name="SiluAndMul", |  | ||||||
|         ) |         ) | ||||||
|     }, |     }, | ||||||
|     "SiluAndMulStringDevice": { |     "SiluAndMulStringDevice": { | ||||||
|         "cuda": LayerRepository( |         "cuda": LayerRepository( | ||||||
|             repo_id="kernels-community/activation", |             repo_id="kernels-community/activation", | ||||||
|             layer_name="SiluAndMul", |             layer_name="SiluAndMul", | ||||||
|  |             revision="layers", | ||||||
|         ) |         ) | ||||||
|     }, |     }, | ||||||
| } | } | ||||||
| @ -51,11 +43,6 @@ class SiluAndMul(nn.Module): | |||||||
|         return F.silu(input[..., :d]) * input[..., d:] |         return F.silu(input[..., :d]) * input[..., d:] | ||||||
|  |  | ||||||
|  |  | ||||||
| @use_kernel_forward_from_hub("SiluAndMulNoCompile") |  | ||||||
| class SiluAndMulNoCompileKernel(SiluAndMul): |  | ||||||
|     pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @use_kernel_forward_from_hub("SiluAndMul") | @use_kernel_forward_from_hub("SiluAndMul") | ||||||
| class SiluAndMulWithKernel(SiluAndMul): | class SiluAndMulWithKernel(SiluAndMul): | ||||||
|     pass |     pass | ||||||
| @ -66,18 +53,6 @@ class SiluAndMulStringDevice(SiluAndMul): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| @use_kernel_forward_from_hub("Linear") |  | ||||||
| class TorchLinearWithCounter(nn.Linear): |  | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         super().__init__(*args, **kwargs) |  | ||||||
|         # Used to check that we called hub kernel. |  | ||||||
|         self.n_calls = 0 |  | ||||||
|  |  | ||||||
|     def forward(self, input: torch.Tensor) -> torch.Tensor: |  | ||||||
|         self.n_calls += 1 |  | ||||||
|         return super().forward(input) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_arg_kinds(): | def test_arg_kinds(): | ||||||
|     @use_kernel_forward_from_hub("ArgKind") |     @use_kernel_forward_from_hub("ArgKind") | ||||||
|     class ArgKind(nn.Module): |     class ArgKind(nn.Module): | ||||||
| @ -96,7 +71,6 @@ def test_arg_kinds(): | |||||||
|     assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5) |     assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) | @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) | ||||||
| @pytest.mark.parametrize("device", ["cuda", "cpu"]) | @pytest.mark.parametrize("device", ["cuda", "cpu"]) | ||||||
| def test_hub_forward(cls, device): | def test_hub_forward(cls, device): | ||||||
| @ -106,7 +80,7 @@ def test_hub_forward(cls, device): | |||||||
|     X = torch.randn((32, 64), device=device) |     X = torch.randn((32, 64), device=device) | ||||||
|     Y = silu_and_mul(X) |     Y = silu_and_mul(X) | ||||||
|  |  | ||||||
|     silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE) |     silu_and_mul_with_kernel = cls() | ||||||
|     Y_kernel = silu_and_mul_with_kernel(X) |     Y_kernel = silu_and_mul_with_kernel(X) | ||||||
|  |  | ||||||
|     torch.testing.assert_close(Y_kernel, Y) |     torch.testing.assert_close(Y_kernel, Y) | ||||||
| @ -124,70 +98,11 @@ def test_layer_fallback_works(): | |||||||
|         pass |         pass | ||||||
|  |  | ||||||
|     # Check that we don't raise an exception for a non-existing kernel. |     # Check that we don't raise an exception for a non-existing kernel. | ||||||
|     silu_and_mul = SiluAndMulWithKernelFallback() |     SiluAndMulWithKernelFallback() | ||||||
|     kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel]) |  | ||||||
| @pytest.mark.parametrize("device", ["cuda"]) |  | ||||||
| def test_torch_compile_layer_without_fallback(cls, device): |  | ||||||
|     silu_and_mul = SiluAndMul() |  | ||||||
|  |  | ||||||
|     X = torch.randn((32, 64), dtype=torch.float32, device=device) |  | ||||||
|     Y = silu_and_mul(X) |  | ||||||
|  |  | ||||||
|     silu_and_mul_with_kernel = cls() |  | ||||||
|     silu_and_mul_with_kernel.eval() |  | ||||||
|  |  | ||||||
|     ctx = ( |  | ||||||
|         pytest.raises(ValueError, match="does not support mode") |  | ||||||
|         if cls is SiluAndMulNoCompileKernel |  | ||||||
|         else nullcontext() |  | ||||||
|     ) |  | ||||||
|     with ctx: |  | ||||||
|         silu_and_mul_with_kernel = kernelize( |  | ||||||
|             silu_and_mul_with_kernel, |  | ||||||
|             device=device, |  | ||||||
|             mode=Mode.INFERENCE | Mode.TORCH_COMPILE, |  | ||||||
|             use_fallback=False, |  | ||||||
|         ) |  | ||||||
|     silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True) |  | ||||||
|  |  | ||||||
|     Y_compiled = silu_and_mul_compiled(X) |  | ||||||
|  |  | ||||||
|     torch.testing.assert_close(Y_compiled, Y) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel]) |  | ||||||
| @pytest.mark.parametrize("device", ["cuda"]) |  | ||||||
| def test_torch_compile_layer_with_fallback(cls, device): |  | ||||||
|     silu_and_mul = SiluAndMul() |  | ||||||
|  |  | ||||||
|     X = torch.randn((32, 64), dtype=torch.float32, device=device) |  | ||||||
|     Y = silu_and_mul(X) |  | ||||||
|  |  | ||||||
|     silu_and_mul_with_kernel = cls() |  | ||||||
|     silu_and_mul_with_kernel.eval() |  | ||||||
|     silu_and_mul_with_kernel = kernelize( |  | ||||||
|         silu_and_mul_with_kernel, |  | ||||||
|         device=device, |  | ||||||
|         mode=Mode.INFERENCE | Mode.TORCH_COMPILE, |  | ||||||
|     ) |  | ||||||
|     silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True) |  | ||||||
|  |  | ||||||
|     Y_compiled = silu_and_mul_compiled(X) |  | ||||||
|  |  | ||||||
|     torch.testing.assert_close(Y_compiled, Y) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_mapping_contexts(): | def test_mapping_contexts(): | ||||||
|     assert set(_KERNEL_MAPPING.get().keys()) == { |     assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"} | ||||||
|         "SiluAndMul", |  | ||||||
|         "SiluAndMulStringDevice", |  | ||||||
|         "SiluAndMulNoCompile", |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     extra_mapping1 = { |     extra_mapping1 = { | ||||||
|         "TestKernel": { |         "TestKernel": { | ||||||
| @ -203,7 +118,6 @@ def test_mapping_contexts(): | |||||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { |         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|             "SiluAndMul", |             "SiluAndMul", | ||||||
|             "SiluAndMulStringDevice", |             "SiluAndMulStringDevice", | ||||||
|             "SiluAndMulNoCompile", |  | ||||||
|             "TestKernel", |             "TestKernel", | ||||||
|         } |         } | ||||||
|  |  | ||||||
| @ -221,26 +135,20 @@ def test_mapping_contexts(): | |||||||
|             assert set(_KERNEL_MAPPING.get().keys()) == { |             assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|                 "SiluAndMul", |                 "SiluAndMul", | ||||||
|                 "SiluAndMulStringDevice", |                 "SiluAndMulStringDevice", | ||||||
|                 "SiluAndMulNoCompile", |  | ||||||
|                 "TestKernel", |                 "TestKernel", | ||||||
|             } |             } | ||||||
|             assert ( |             assert ( | ||||||
|                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][ |                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|                     Mode.DEFAULT |  | ||||||
|                 ].repo_id |  | ||||||
|                 == "kernels-community/non-existing" |                 == "kernels-community/non-existing" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { |         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|             "SiluAndMul", |             "SiluAndMul", | ||||||
|             "SiluAndMulStringDevice", |             "SiluAndMulStringDevice", | ||||||
|             "SiluAndMulNoCompile", |  | ||||||
|             "TestKernel", |             "TestKernel", | ||||||
|         } |         } | ||||||
|         assert ( |         assert ( | ||||||
|             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][ |             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|                 Mode.DEFAULT |  | ||||||
|             ].repo_id |  | ||||||
|             == "kernels-community/activation" |             == "kernels-community/activation" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
| @ -249,29 +157,23 @@ def test_mapping_contexts(): | |||||||
|                 "SiluAndMul", |                 "SiluAndMul", | ||||||
|             } |             } | ||||||
|             assert ( |             assert ( | ||||||
|                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][ |                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|                     Mode.DEFAULT |  | ||||||
|                 ].repo_id |  | ||||||
|                 == "kernels-community/non-existing" |                 == "kernels-community/non-existing" | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|         assert set(_KERNEL_MAPPING.get().keys()) == { |         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|             "SiluAndMul", |             "SiluAndMul", | ||||||
|             "SiluAndMulStringDevice", |             "SiluAndMulStringDevice", | ||||||
|             "SiluAndMulNoCompile", |  | ||||||
|             "TestKernel", |             "TestKernel", | ||||||
|         } |         } | ||||||
|         assert ( |         assert ( | ||||||
|             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][ |             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|                 Mode.DEFAULT |  | ||||||
|             ].repo_id |  | ||||||
|             == "kernels-community/activation" |             == "kernels-community/activation" | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     assert set(_KERNEL_MAPPING.get().keys()) == { |     assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|         "SiluAndMul", |         "SiluAndMul", | ||||||
|         "SiluAndMulStringDevice", |         "SiluAndMulStringDevice", | ||||||
|         "SiluAndMulNoCompile", |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -303,174 +205,20 @@ def test_validate_kernel_layer(): | |||||||
|         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul) |         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_invalid_mode_for_mapping_rejected(): |  | ||||||
|     linear = TorchLinearWithCounter(32, 32).to("cuda") |  | ||||||
|  |  | ||||||
|     with use_kernel_mapping( |  | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 "cuda": { |  | ||||||
|                     Mode.TRAINING: LayerRepository( |  | ||||||
|                         repo_id="kernels-test/backward-marker-test", |  | ||||||
|                         layer_name="LinearNoBackward", |  | ||||||
|                     ) |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         with pytest.raises(ValueError, match="does not support backward"): |  | ||||||
|             kernelize(linear, mode=Mode.TRAINING) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_kernel_modes(): |  | ||||||
|     linear = TorchLinearWithCounter(32, 32).to("cuda") |  | ||||||
|  |  | ||||||
|     # Case 1: layer without further specification, becomes the |  | ||||||
|     #         base layer. |  | ||||||
|     with use_kernel_mapping( |  | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 "cuda": LayerRepository( |  | ||||||
|                     repo_id="kernels-test/backward-marker-test", |  | ||||||
|                     layer_name="LinearBackward", |  | ||||||
|                 ) |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |  | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 0 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING) |  | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 0 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) |  | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 0 |  | ||||||
|  |  | ||||||
|     # Case 2: register a kernel just for training. If no base kernel |  | ||||||
|     #         layer is registered, we fall back to the original layer. |  | ||||||
|     with use_kernel_mapping( |  | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 "cuda": { |  | ||||||
|                     Mode.TRAINING: LayerRepository( |  | ||||||
|                         repo_id="kernels-test/backward-marker-test", |  | ||||||
|                         layer_name="LinearBackward", |  | ||||||
|                     ) |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |  | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 1 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING) |  | ||||||
|         linear(X) |  | ||||||
|         # Training has a kernel, so fallback. |  | ||||||
|         assert linear.n_calls == 1 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) |  | ||||||
|         linear(X) |  | ||||||
|         # No kernel for training + torch.compile, so fallback. |  | ||||||
|         assert linear.n_calls == 2 |  | ||||||
|  |  | ||||||
|     # Case 3: register a kernel just for training and one for fallback. |  | ||||||
|     with use_kernel_mapping( |  | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 "cuda": { |  | ||||||
|                     Mode.DEFAULT: LayerRepository( |  | ||||||
|                         repo_id="kernels-test/backward-marker-test", |  | ||||||
|                         layer_name="LinearBackward", |  | ||||||
|                     ), |  | ||||||
|                     Mode.TRAINING: LayerRepository( |  | ||||||
|                         repo_id="kernels-test/backward-marker-test", |  | ||||||
|                         layer_name="LinearBackward", |  | ||||||
|                     ), |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |  | ||||||
|         linear(X) |  | ||||||
|         # Uses the base kernel. |  | ||||||
|         assert linear.n_calls == 2 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING) |  | ||||||
|         linear(X) |  | ||||||
|         # Uses the training kernel. |  | ||||||
|         assert linear.n_calls == 2 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) |  | ||||||
|         linear(X) |  | ||||||
|         # Uses the base kernel. |  | ||||||
|         assert linear.n_calls == 2 |  | ||||||
|  |  | ||||||
|     # Case 4: register a kernel with two preferences. |  | ||||||
|     with use_kernel_mapping( |  | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 "cuda": { |  | ||||||
|                     Mode.TRAINING |  | ||||||
|                     | Mode.TORCH_COMPILE: LayerRepository( |  | ||||||
|                         repo_id="kernels-test/backward-marker-test", |  | ||||||
|                         layer_name="LinearBackward", |  | ||||||
|                     ) |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |  | ||||||
|         linear(X) |  | ||||||
|         # No inference kernel, so fallback. |  | ||||||
|         assert linear.n_calls == 3 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING) |  | ||||||
|         linear(X) |  | ||||||
|         # No training kernel, so fallback. |  | ||||||
|         assert linear.n_calls == 4 |  | ||||||
|  |  | ||||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) |  | ||||||
|         linear(X) |  | ||||||
|         # We do have a training + torch.compile kernel. |  | ||||||
|         assert linear.n_calls == 4 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.linux_only |  | ||||||
| def test_fallback_used_when_training(): | def test_fallback_used_when_training(): | ||||||
|     linear = TorchLinearWithCounter(32, 32).to("cuda") |     @use_kernel_forward_from_hub("Linear") | ||||||
|  |     class TorchLinear(nn.Linear): | ||||||
|  |         def __init__(self, *args, **kwargs): | ||||||
|  |             super().__init__(*args, **kwargs) | ||||||
|  |             # Used to check that we called hub kernel. | ||||||
|  |             self.n_calls = 0 | ||||||
|  |  | ||||||
|     # Case 1: kernel with explicit backward support should always |         def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|     #         use the kernel. |             self.n_calls += 1 | ||||||
|     with use_kernel_mapping( |             return super().forward(input) | ||||||
|         { |  | ||||||
|             "Linear": { |  | ||||||
|                 Device(type="cuda"): LayerRepository( |  | ||||||
|                     repo_id="kernels-test/backward-marker-test", |  | ||||||
|                     layer_name="LinearBackward", |  | ||||||
|                 ) |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     ): |  | ||||||
|         linear.train() |  | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |  | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 0 |  | ||||||
|  |  | ||||||
|         linear.eval() |     linear = TorchLinear(32, 32).to("cuda") | ||||||
|         linear(X) |  | ||||||
|         assert linear.n_calls == 0 |  | ||||||
|  |  | ||||||
|     # Case 2: kernel with implicit backward support should always |  | ||||||
|     #         use the kernel. |  | ||||||
|     with use_kernel_mapping( |     with use_kernel_mapping( | ||||||
|         { |         { | ||||||
|             "Linear": { |             "Linear": { | ||||||
| @ -482,7 +230,6 @@ def test_fallback_used_when_training(): | |||||||
|         } |         } | ||||||
|     ): |     ): | ||||||
|         linear.train() |         linear.train() | ||||||
|         kernelize(linear, mode=Mode.INFERENCE) |  | ||||||
|         X = torch.randn(10, 32, device="cuda") |         X = torch.randn(10, 32, device="cuda") | ||||||
|         linear(X) |         linear(X) | ||||||
|         assert linear.n_calls == 0 |         assert linear.n_calls == 0 | ||||||
| @ -491,18 +238,40 @@ def test_fallback_used_when_training(): | |||||||
|         linear(X) |         linear(X) | ||||||
|         assert linear.n_calls == 0 |         assert linear.n_calls == 0 | ||||||
|  |  | ||||||
|  |     with use_kernel_mapping( | ||||||
| def test_invalid_mode_rejected(): |         { | ||||||
|     with pytest.raises(ValueError, match="mutually exclusive"): |             "Linear": { | ||||||
|         _ = Mode.INFERENCE | Mode.TRAINING |                 Device(type="cuda"): LayerRepository( | ||||||
|  |                     repo_id="kernels-test/backward-marker-test", | ||||||
|     with pytest.raises(ValueError, match="cannot be combined with other modes"): |                     layer_name="LinearBackward", | ||||||
|         _ = Mode.DEFAULT | Mode.TORCH_COMPILE |                 ) | ||||||
|  |             } | ||||||
|     with pytest.raises( |         } | ||||||
|         ValueError, match="can only be used to register kernel mappings" |  | ||||||
|     ): |     ): | ||||||
|         kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT) |         linear.train() | ||||||
|  |         X = torch.randn(10, 32, device="cuda") | ||||||
|  |         linear(X) | ||||||
|  |         assert linear.n_calls == 0 | ||||||
|  |  | ||||||
|     with pytest.raises(ValueError, match="mode must contain"): |         linear.eval() | ||||||
|         kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE) |         linear(X) | ||||||
|  |         assert linear.n_calls == 0 | ||||||
|  |  | ||||||
|  |     with use_kernel_mapping( | ||||||
|  |         { | ||||||
|  |             "Linear": { | ||||||
|  |                 Device(type="cuda"): LayerRepository( | ||||||
|  |                     repo_id="kernels-test/backward-marker-test", | ||||||
|  |                     layer_name="LinearNoBackward", | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     ): | ||||||
|  |         linear.train() | ||||||
|  |         X = torch.randn(10, 32, device="cuda") | ||||||
|  |         linear(X) | ||||||
|  |         assert linear.n_calls == 1 | ||||||
|  |  | ||||||
|  |         linear.eval() | ||||||
|  |         linear(X) | ||||||
|  |         assert linear.n_calls == 1 | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user
	