mirror of
				https://github.com/huggingface/kernels.git
				synced 2025-10-27 00:54:28 +08:00 
			
		
		
		
	Compare commits
	
		
			23 Commits
		
	
	
		
			type-kerne
			...
			compile-no
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 29a930487a | |||
| b0c431fee4 | |||
| 9a188eadbe | |||
| 457c7c1b8d | |||
| fb8cd99a2c | |||
| dfee307d54 | |||
| 93e5765611 | |||
| bf488208be | |||
| 2a14472e4c | |||
| 055a953552 | |||
| 692d5ad458 | |||
| 2139df57f4 | |||
| 8f9a77bb6a | |||
| 6c00194680 | |||
| d6b51eefb7 | |||
| d383fdd4b4 | |||
| 07e5e8481a | |||
| 88f55d4728 | |||
| e801ebf332 | |||
| 0ae07f05fc | |||
| 7611021100 | |||
| 767e7ccf13 | |||
| 1caa4c1393 | 
							
								
								
									
										15
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										15
									
								
								.github/workflows/test.yml
									
									
									
									
										vendored
									
									
								
							| @ -51,7 +51,15 @@ jobs: | ||||
|         run: uv run mypy src/kernels | ||||
|  | ||||
|       - name: Run tests | ||||
|         run: uv run pytest tests | ||||
|         run: | | ||||
|           uv run pytest tests | ||||
|  | ||||
|       - name: Run staging tests | ||||
|         env: | ||||
|           HF_TOKEN: ${{ secrets.HF_STAGING_TOKEN }} | ||||
|         run: | | ||||
|           HUGGINGFACE_CO_STAGING=true uv run pytest --token -m "is_staging_test" tests/ | ||||
|         if: matrix.python_version == '3.10' && matrix.torch-version == '2.7.0' | ||||
|  | ||||
|       - name: Check kernel conversion | ||||
|         run: | | ||||
| @ -65,6 +73,11 @@ jobs: | ||||
|         run: | | ||||
|           uv run kernels generate-readme kernels-community/triton-layer-norm | ||||
|  | ||||
|       - name: Check kernel check | ||||
|         run: | | ||||
|           uv pip install kernel-abi-check | ||||
|           kernels check kernels-community/activation | ||||
|  | ||||
|       - name: Import check without torch | ||||
|         run: | | ||||
|           uv pip uninstall torch | ||||
|  | ||||
							
								
								
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								Makefile
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,8 @@ | ||||
| .PHONY: style | ||||
|  | ||||
| export check_dirs := src examples tests | ||||
|  | ||||
| style: | ||||
| 	black ${check_dirs} | ||||
| 	isort ${check_dirs} | ||||
| 	ruff check ${check_dirs} --fix | ||||
| @ -62,7 +62,6 @@ the Hub. | ||||
| - [Using layers](docs/source/layers.md) | ||||
| - [Locking kernel/layer versions](docs/source/locking.md) | ||||
| - [Environment variables](docs/source/env.md) | ||||
| - [Using kernels in a Docker container](docs/source/docker.md) | ||||
| - [Kernel requirements](docs/source/kernel-requirements.md) | ||||
| - [Frequently Asked Questions](docs/source/faq.md) | ||||
| - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||
|  | ||||
| @ -21,6 +21,8 @@ | ||||
|       title: Kernels | ||||
|     - local: api/layers | ||||
|       title: Layers | ||||
|     - local: cli | ||||
|       title: Kernels CLI | ||||
|   title: API Reference | ||||
| - sections: | ||||
|     - local: kernel-requirements | ||||
|  | ||||
| @ -21,6 +21,22 @@ activation.gelu_fast(y, x) | ||||
| print(y) | ||||
| ``` | ||||
|  | ||||
| ### Using version bounds | ||||
|  | ||||
| Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`. | ||||
| You can specify which version to download using Python version specifiers: | ||||
|  | ||||
| ```python | ||||
| import torch | ||||
| from kernels import get_kernel | ||||
|  | ||||
| activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0") | ||||
| ``` | ||||
|  | ||||
| This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It | ||||
| is strongly recommended to specify a version bound, since a kernel author | ||||
| might push incompatible changes to the `main` branch. | ||||
|  | ||||
| ## Checking Kernel Availability | ||||
|  | ||||
| You can check if a specific kernel is available for your environment: | ||||
|  | ||||
							
								
								
									
										58
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								docs/source/cli.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,58 @@ | ||||
| # Kernels CLI Reference | ||||
|  | ||||
| ## Main Functions | ||||
|  | ||||
| ### kernels check | ||||
|  | ||||
| You can use `kernels check` to test compliance of a kernel on the Hub. | ||||
| This currently checks that the kernel: | ||||
|  | ||||
| - Supports the currently-required Python ABI version. | ||||
| - Works on supported operating system versions. | ||||
|  | ||||
| For example: | ||||
|  | ||||
| ```bash | ||||
| $ kernels check kernels-community/flash-attn3 | ||||
| Checking variant: torch28-cxx11-cu128-aarch64-linux | ||||
|   🐍 Python ABI 3.9 compatible | ||||
|   🐧 manylinux_2_28 compatible | ||||
| [...] | ||||
| ``` | ||||
|  | ||||
| ### kernels to-wheel | ||||
|  | ||||
| We strongly recommend downloading kernels from the Hub using the `kernels` | ||||
| package, since this comes with large [benefits](index.md) over using Python | ||||
| wheels. That said, some projects may require deployment of kernels as | ||||
| wheels. The `kernels` utility provides a simple solution to this. You can | ||||
| convert any Hub kernel into a set of wheels with the `to-wheel` command: | ||||
|  | ||||
| ```bash | ||||
| $ kernels to-wheel drbh/img2grey 1.1.2 | ||||
| ☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl | ||||
| ``` | ||||
|  | ||||
| ### kernels upload | ||||
|  | ||||
| Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload | ||||
| your kernel builds to the Hub. | ||||
|  | ||||
| **Notes**: | ||||
|  | ||||
| - This will take care of creating a repository on the Hub with the `repo_id` provided. | ||||
| - If a repo with the given `repo_id` already exists and its `build` directory already contains | ||||
|   the build variant being uploaded, the existing files of that variant are deleted as part of the upload. | ||||
| - Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub. | ||||
| @ -1,6 +1,8 @@ | ||||
| # FAQ | ||||
|  | ||||
| ## Why is the kernelization step needed? | ||||
| ## Kernel layers | ||||
|  | ||||
| ### Why is the kernelization step needed as a separate step? | ||||
|  | ||||
| In earlier versions of `kernels`, a layer's `forward` method was replaced | ||||
| by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. | ||||
| @ -11,3 +13,29 @@ on data-dependent branching. | ||||
|  | ||||
| To avoid branching, we have to make dispatch decisions ahead of time, | ||||
| which is what the `kernelize` function does. | ||||
|  | ||||
| ### Why does kernelization only replace `forward` methods? | ||||
|  | ||||
| There are some other possible approaches. The first is to completely | ||||
| replace existing layers by kernel layers. However, since this would | ||||
| permit free-form layer classes, it would be much harder to validate | ||||
| that layers are fully compatible with the layers that they are | ||||
| replacing. For instance, they could have completely different member | ||||
| variables. Besides that, we would also need to hold on to the original | ||||
| layers, in case we need to revert to the base layers when the model | ||||
| is `kernelize`d again with different options. | ||||
|  | ||||
| A second approach would be to make an auxiliary layer that wraps the | ||||
| original layer and the kernel layer and dispatches to the kernel layer. | ||||
| This wouldn't have the issues of the first approach, because kernel layers | ||||
| could be just as strict as they are now, and we would still have access | ||||
| to the original layers when `kernelize`-ing the model again. However, | ||||
| this would change the graph structure of the model and would break use | ||||
| cases where programs access the model internals (e.g. | ||||
| `model.layers[0].attention.query_weight`) or rely on the graph structure | ||||
| in other ways. | ||||
|  | ||||
| The approach of `forward`-replacement is the least invasive, because | ||||
| it preserves the original model graph. It is also reversible, since | ||||
| even though the `forward` of a layer _instance_ might be replaced, | ||||
| the corresponding class still has the original `forward`. | ||||
|  | ||||
| @ -34,6 +34,8 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of | ||||
| the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md) | ||||
| to resolve the version constraints. | ||||
|  | ||||
| We recommend using [semver](https://semver.org/) to version kernels. | ||||
|  | ||||
| ## Native Python module | ||||
|  | ||||
| Kernels will typically contain a native Python module with precompiled | ||||
| @ -44,19 +46,28 @@ have dynamic library dependencies outside: | ||||
| - Torch; | ||||
| - CUDA/ROCm libraries installed as dependencies of Torch. | ||||
|  | ||||
| ## Compatibility with torch.compile | ||||
|  | ||||
| The Kernel Hub also encourages writing kernels in a `torch.compile`-compliant | ||||
| way. This helps to ensure that the kernels are compatible with | ||||
| `torch.compile` without introducing graph breaks or triggering | ||||
| recompilation, either of which can limit the benefits of compilation. | ||||
|  | ||||
| [Here](https://github.com/huggingface/kernel-builder/blob/d1ee9bf9301ac8c5199099d90ee1c9d5c789d5ba/examples/relu-backprop-compile/tests/test_relu.py#L162) is a simple test example which checks for graph breaks and  | ||||
| recompilation triggers during `torch.compile`. | ||||
|  | ||||
| ### Linux | ||||
|  | ||||
| - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||
|   for compatibility with Python 3.9 and later. | ||||
| - Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based). | ||||
|   This means that the extension **must not** use symbol versions higher than: | ||||
|  | ||||
|   - GLIBC 2.28 | ||||
|   - GLIBCXX 3.4.24 | ||||
|   - CXXABI 1.3.11 | ||||
|   - GCC 7.0.0 | ||||
|  | ||||
| These requirement can be checked with the ABI checker (see below). | ||||
| These requirements can be checked with the ABI checker (see below). | ||||
|  | ||||
| ### macOS | ||||
|  | ||||
|  | ||||
| @ -5,7 +5,7 @@ the Hub can replace the `forward` method of an existing layer for a certain | ||||
| device type. This makes it possible to provide more performant kernels for | ||||
| existing layers. | ||||
|  | ||||
| See [Kernel requirements](kernel-requirements.md) for more information the | ||||
| See [Kernel requirements](kernel-requirements.md) for more information on the | ||||
| requirements of Hub layers. | ||||
|  | ||||
| ## Making a layer extensible with kernels from the hub | ||||
| @ -84,12 +84,6 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE) | ||||
| model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
| ``` | ||||
|  | ||||
| When the `mode` argument is not specified, | ||||
| `Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode | ||||
| aligns most closely with pure PyTorch layers which also support training | ||||
| and `torch.compile`. However, to select the most performant kernels, it | ||||
| is often good to make the mode specific as possible. | ||||
|  | ||||
| ### Kernel device | ||||
|  | ||||
| Kernels can be registered per device type. For instance, separate `cuda` and | ||||
| @ -117,7 +111,7 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback= | ||||
|  | ||||
| This can be useful if you want to guarantee that Hub kernels are used. | ||||
|  | ||||
| ### Inspecting kernels which kernels are used | ||||
| ### Inspecting which kernels are used | ||||
|  | ||||
| The kernels that are used are logged at the `INFO` level by `kernelize`. | ||||
| See the [Python logging](https://docs.python.org/3/library/logging.html) | ||||
| @ -157,12 +151,39 @@ used with the `use_kernel_mapping` context manager: | ||||
| ```python | ||||
| with use_kernel_mapping(kernel_layer_mapping): | ||||
|     # Use the layer for which the mapping is applied. | ||||
|     model = kernelize(model) | ||||
|     model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
| ``` | ||||
|  | ||||
| This ensures that the mapping is not active anymore outside the | ||||
| `with`-scope. | ||||
|  | ||||
| ### Using version bounds | ||||
|  | ||||
| Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`. | ||||
| You can specify which version of the kernel to download using Python version | ||||
| specifiers: | ||||
|  | ||||
| ```python | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         "cuda": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             version=">=0.0.4,<0.1.0", | ||||
|         ), | ||||
|         "rocm": LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|             version=">=0.0.4,<0.1.0", | ||||
|         ) | ||||
|     } | ||||
| } | ||||
| ``` | ||||
|  | ||||
| This will get the layer from the latest kernel tagged `v0.0.z` where `z` is at | ||||
| least 4. It is strongly recommended to specify a version bound, since a | ||||
| kernel author might push incompatible changes to the `main` branch. | ||||
|  | ||||
| ### Registering kernels for specific modes | ||||
|  | ||||
| You might want to register two different kernels for a particular layer, | ||||
| @ -265,7 +286,6 @@ Capabilities behave as follows: | ||||
|   an existing kernel, the new kernel will replace the old kernel. | ||||
| - When there are multiple kernels that support a capability, the kernel | ||||
|   with the smaller capability interval will be used. E.g. given: | ||||
|  | ||||
|   - `KernelA` with `min_capability=80` and `max_capability=89`; | ||||
|   - `KernelB` with `min_capability=75` and `max_capability=89`; | ||||
|   - `kernelize` runs on a system with capability 8.6. | ||||
|  | ||||
| @ -20,11 +20,11 @@ activation.gelu_fast(y, x) | ||||
| print("Kernel successfully executed") | ||||
|  | ||||
| # Check results | ||||
| expected = torch.tensor([ | ||||
|     [0.8408, 1.9551, 2.9961], | ||||
|     [4.0000, 5.0000, 6.0000], | ||||
|     [7.0000, 8.0000, 9.0000] | ||||
| ], device='cuda:0', dtype=torch.float16) | ||||
| expected = torch.tensor( | ||||
|     [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]], | ||||
|     device="cuda:0", | ||||
|     dtype=torch.float16, | ||||
| ) | ||||
| assert torch.allclose(y, expected) | ||||
|  | ||||
| print("Calculated values are exact") | ||||
|  | ||||
| @ -24,6 +24,7 @@ | ||||
|       in | ||||
|       { | ||||
|         formatter = pkgs.nixfmt-tree; | ||||
|         packages.kernel-abi-check = pkgs.python3.pkgs.callPackage ./nix/kernel-abi-check.nix {}; | ||||
|         devShells = with pkgs; rec { | ||||
|           default = mkShell { | ||||
|             nativeBuildInputs = [ | ||||
| @ -40,6 +41,7 @@ | ||||
|               ++ (with python3.pkgs; [ | ||||
|                 docutils | ||||
|                 huggingface-hub | ||||
|                 (callPackage ./nix/kernel-abi-check.nix {}) | ||||
|                 mktestdocs | ||||
|                 pytest | ||||
|                 pytest-benchmark | ||||
|  | ||||
							
								
								
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								nix/kernel-abi-check.nix
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| { | ||||
|   buildPythonPackage, | ||||
|   fetchPypi, | ||||
|   rustPlatform, | ||||
| }: | ||||
|  | ||||
| buildPythonPackage rec { | ||||
|   pname = "kernel-abi-check"; | ||||
|   version = "0.6.2"; | ||||
|  | ||||
|   src = fetchPypi { | ||||
|     inherit version; | ||||
|     pname = "kernel_abi_check"; | ||||
|     hash = "sha256-goWC7SK79FVNEvkp3bISBwbOqdSrmobANtrWIve9/Ys="; | ||||
|   }; | ||||
|  | ||||
|   cargoDeps = rustPlatform.fetchCargoVendor { | ||||
|     inherit pname version src sourceRoot; | ||||
|     hash = "sha256-+1jdbKsDKmG+bf0NEVYMv8t7Meuge1z2cgYfbdB9q8A="; | ||||
|   }; | ||||
|  | ||||
|   sourceRoot = "kernel_abi_check-${version}/bindings/python"; | ||||
|  | ||||
|   pyproject = true; | ||||
|  | ||||
|   nativeBuildInputs = with rustPlatform; [ cargoSetupHook maturinBuildHook ]; | ||||
| } | ||||
| @ -1,6 +1,6 @@ | ||||
| [project] | ||||
| name = "kernels" | ||||
| version = "0.9.0.dev0" | ||||
| version = "0.10.2.dev0" | ||||
| description = "Download compute kernels" | ||||
| authors = [ | ||||
|   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||
| @ -12,7 +12,7 @@ license = { text = "Apache-2.0" } | ||||
| readme = "README.md" | ||||
| requires-python = ">= 3.9" | ||||
| dependencies = [ | ||||
|   "huggingface_hub>=0.26.0,<1.0", | ||||
|   "huggingface_hub>=0.26.0,<2.0", | ||||
|   "packaging>=20.0", | ||||
|   "pyyaml>=6", | ||||
|   "tomli>=2.0; python_version<'3.11'", | ||||
| @ -34,6 +34,7 @@ dev = [ | ||||
| ] | ||||
|  | ||||
| [project.optional-dependencies] | ||||
| abi-check = ["kernel-abi-check>=0.6.2,<0.7.0"] | ||||
| torch = ["torch"] | ||||
| docs = [ | ||||
|   "hf-doc-builder", | ||||
| @ -45,6 +46,9 @@ kernels = "kernels.cli:main" | ||||
| [project.entry-points."egg_info.writers"] | ||||
| "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||
|  | ||||
| [tool.isort] | ||||
| profile = "black" | ||||
| line_length = 119 | ||||
|  | ||||
| [tool.ruff] | ||||
| exclude = [ | ||||
| @ -71,4 +75,4 @@ line-length = 119 | ||||
| # Ignored rules: | ||||
| # "E501" -> line length violation | ||||
| lint.ignore = ["E501"] | ||||
| lint.select = ["E", "F", "I", "W"] | ||||
| lint.select = ["E", "F", "W"] | ||||
|  | ||||
| @ -3,3 +3,7 @@ markers = | ||||
|     cuda_only: marks tests that should only run on hosts with CUDA GPUs | ||||
|     rocm_only: marks tests that should only run on hosts with ROCm GPUs | ||||
|     darwin_only: marks tests that should only run on macOS | ||||
|     xpu_only: marks tests that should only run on hosts with Intel XPUs | ||||
|     npu_only: marks tests that should only run on Ascend NPUs | ||||
|     token: enable tests that require a write token | ||||
|     is_staging_test: marks tests that should only run on a staging environment | ||||
|  | ||||
| @ -1,3 +1,7 @@ | ||||
| import importlib.metadata | ||||
|  | ||||
| __version__ = importlib.metadata.version("kernels") | ||||
|  | ||||
| from kernels.layer import ( | ||||
|     CUDAProperties, | ||||
|     Device, | ||||
| @ -21,6 +25,7 @@ from kernels.utils import ( | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     "__version__", | ||||
|     "CUDAProperties", | ||||
|     "Device", | ||||
|     "LayerRepository", | ||||
|  | ||||
							
								
								
									
										141
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								src/kernels/check.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,141 @@ | ||||
| from pathlib import Path | ||||
| import sys | ||||
|  | ||||
| from huggingface_hub import snapshot_download | ||||
| from kernels.utils import CACHE_DIR | ||||
| from kernel_abi_check import ( | ||||
|     BinaryFormat, | ||||
|     IncompatibleMacOSVersion, | ||||
|     ObjectFile, | ||||
|     IncompatibleAbi3Symbol, | ||||
|     NonAbi3Symbol, | ||||
|     IncompatibleManylinuxSymbol, | ||||
|     MissingMacOSVersion, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def check_kernel( | ||||
|     *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str | ||||
| ): | ||||
|     variants_path = ( | ||||
|         Path( | ||||
|             snapshot_download( | ||||
|                 repo_id, | ||||
|                 allow_patterns=["build/*"], | ||||
|                 cache_dir=CACHE_DIR, | ||||
|                 revision=revision, | ||||
|             ) | ||||
|         ) | ||||
|         / "build" | ||||
|     ) | ||||
|  | ||||
|     has_issues = False | ||||
|     for variant_path in variants_path.iterdir(): | ||||
|         if not variant_path.is_dir(): | ||||
|             print( | ||||
|                 f"⛔ `build/` must only contain directories, found: {variant_path.name}", | ||||
|                 file=sys.stderr, | ||||
|             ) | ||||
|             has_issues = True | ||||
|             continue | ||||
|  | ||||
|         print(f"Checking variant: {variant_path.name}", file=sys.stderr) | ||||
|  | ||||
|         indent = 2 | ||||
|  | ||||
|         for dylib_path in variant_path.rglob("*.so"): | ||||
|             print_with_indent( | ||||
|                 indent, | ||||
|                 f"Dynamic library {dylib_path.relative_to(variant_path)}:", | ||||
|             ) | ||||
|  | ||||
|             o = ObjectFile(dylib_path) | ||||
|             has_issues |= check_abi3(o, python_abi, indent + 2) | ||||
|  | ||||
|             # TODO: also check operating system | ||||
|             if o.format() == BinaryFormat.ELF: | ||||
|                 has_issues |= check_manylinux(o, manylinux, indent + 2) | ||||
|             elif o.format() == BinaryFormat.MACH_O: | ||||
|                 has_issues |= check_macos(o, macos, indent + 2) | ||||
|  | ||||
|     if has_issues: | ||||
|         sys.exit(1) | ||||
|  | ||||
|  | ||||
| def check_abi3(object_file: ObjectFile, python_abi: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_python_abi(python_abi) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found symbols that are incompatible with Python ABI {python_abi}:", | ||||
|         ) | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, IncompatibleAbi3Symbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}: {violation.version_added}", | ||||
|                 ) | ||||
|             elif isinstance(violation, NonAbi3Symbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🐍 Python ABI {python_abi} compatible") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def check_macos(object_file: ObjectFile, macos: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_macos(macos) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found incompatibility with macOS {macos}:", | ||||
|         ) | ||||
|  | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, MissingMacOSVersion): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     "shared library does not contain macOS version", | ||||
|                 ) | ||||
|             elif isinstance(violation, IncompatibleMacOSVersion): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"shared library requires macOS {violation.version}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🍏 compatible with macOS {macos}") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def check_manylinux(object_file: ObjectFile, manylinux: str, indent: int) -> bool: | ||||
|     has_issues = False | ||||
|     violations = object_file.check_manylinux(manylinux) | ||||
|     if violations != []: | ||||
|         has_issues = True | ||||
|         print_with_indent( | ||||
|             indent, | ||||
|             f"⛔ Found symbols that are incompatible with {manylinux}:", | ||||
|         ) | ||||
|  | ||||
|         for violation in violations: | ||||
|             if isinstance(violation, IncompatibleManylinuxSymbol): | ||||
|                 print_with_indent( | ||||
|                     indent + 3, | ||||
|                     f"{violation.name}_{violation.dep}: {violation.version}", | ||||
|                 ) | ||||
|     else: | ||||
|         print_with_indent(indent, f"🐧 {manylinux} compatible") | ||||
|  | ||||
|     return has_issues | ||||
|  | ||||
|  | ||||
| def print_with_indent(indent: int, message: str): | ||||
|     print(f"{' ' * indent}{message}", file=sys.stderr) | ||||
| @ -4,6 +4,8 @@ import json | ||||
| import sys | ||||
| from pathlib import Path | ||||
|  | ||||
| from huggingface_hub import create_repo, upload_folder | ||||
|  | ||||
| from kernels.compat import tomllib | ||||
| from kernels.lockfile import KernelLock, get_kernel_locks | ||||
| from kernels.utils import install_kernel, install_kernel_all_variants | ||||
| @ -18,6 +20,31 @@ def main(): | ||||
|     ) | ||||
|     subparsers = parser.add_subparsers(required=True) | ||||
|  | ||||
|     check_parser = subparsers.add_parser("check", help="Check a kernel for compliance") | ||||
|     check_parser.add_argument("repo_id", type=str, help="The kernel repo ID") | ||||
|     check_parser.add_argument( | ||||
|         "--revision", | ||||
|         type=str, | ||||
|         default="main", | ||||
|         help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')", | ||||
|     ) | ||||
|     check_parser.add_argument("--macos", type=str, help="macOS version", default="15.0") | ||||
|     check_parser.add_argument( | ||||
|         "--manylinux", type=str, help="Manylinux version", default="manylinux_2_28" | ||||
|     ) | ||||
|     check_parser.add_argument( | ||||
|         "--python-abi", type=str, help="Python ABI version", default="3.9" | ||||
|     ) | ||||
|     check_parser.set_defaults( | ||||
|         func=lambda args: check_kernel( | ||||
|             macos=args.macos, | ||||
|             manylinux=args.manylinux, | ||||
|             python_abi=args.python_abi, | ||||
|             repo_id=args.repo_id, | ||||
|             revision=args.revision, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     download_parser = subparsers.add_parser("download", help="Download locked kernels") | ||||
|     download_parser.add_argument( | ||||
|         "project_dir", | ||||
| @ -31,6 +58,24 @@ def main(): | ||||
|     ) | ||||
|     download_parser.set_defaults(func=download_kernels) | ||||
|  | ||||
|     upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub") | ||||
|     upload_parser.add_argument( | ||||
|         "kernel_dir", | ||||
|         type=Path, | ||||
|         help="Directory of the kernel build", | ||||
|     ) | ||||
|     upload_parser.add_argument( | ||||
|         "--repo_id", | ||||
|         type=str, | ||||
|         help="Repository ID to use to upload to the Hugging Face Hub", | ||||
|     ) | ||||
|     upload_parser.add_argument( | ||||
|         "--private", | ||||
|         action="store_true", | ||||
|         help="If the repository should be private.", | ||||
|     ) | ||||
|     upload_parser.set_defaults(func=upload_kernels) | ||||
|  | ||||
|     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") | ||||
|     lock_parser.add_argument( | ||||
|         "project_dir", | ||||
| @ -153,8 +198,56 @@ def lock_kernels(args): | ||||
|         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||||
|  | ||||
|  | ||||
| def upload_kernels(args): | ||||
|     kernel_dir = Path(args.kernel_dir).resolve() | ||||
|     build_dir = kernel_dir / "build" | ||||
|     if not kernel_dir.is_dir(): | ||||
|         raise ValueError(f"{kernel_dir} is not a directory") | ||||
|     if not build_dir.is_dir(): | ||||
|         raise ValueError("Couldn't find `build` directory inside `kernel_dir`") | ||||
|  | ||||
|     repo_id = create_repo( | ||||
|         repo_id=args.repo_id, private=args.private, exist_ok=True | ||||
|     ).repo_id | ||||
|  | ||||
|     delete_patterns: set[str] = set() | ||||
|     for build_variant in build_dir.iterdir(): | ||||
|         if build_variant.is_dir(): | ||||
|             delete_patterns.add(f"{build_variant.name}/**") | ||||
|  | ||||
|     upload_folder( | ||||
|         repo_id=repo_id, | ||||
|         folder_path=build_dir, | ||||
|         path_in_repo="build", | ||||
|         delete_patterns=list(delete_patterns), | ||||
|         commit_message="Build uploaded using `kernels`.", | ||||
|     ) | ||||
|     print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.") | ||||
|  | ||||
|  | ||||
| class _JSONEncoder(json.JSONEncoder): | ||||
|     def default(self, o): | ||||
|         if dataclasses.is_dataclass(o): | ||||
|             return dataclasses.asdict(o) | ||||
|         return super().default(o) | ||||
|  | ||||
|  | ||||
def check_kernel(
    *, macos: str, manylinux: str, python_abi: str, repo_id: str, revision: str
):
    """Run the kernel check for a repository revision.

    The work is delegated to `kernels.check.check_kernel`; all keyword
    arguments are forwarded unchanged. When `kernels.check` cannot be
    imported, an installation hint is written to stderr and the process
    exits with status 1.
    """
    try:
        import kernels.check
    except ImportError:
        hint = "`kernels check` requires the `kernel-abi-check` package: pip install kernel-abi-check"
        print(hint, file=sys.stderr)
        sys.exit(1)

    forwarded = {
        "macos": macos,
        "manylinux": manylinux,
        "python_abi": python_abi,
        "repo_id": repo_id,
        "revision": revision,
    }
    kernels.check.check_kernel(**forwarded)
|  | ||||
| @ -87,7 +87,7 @@ class Device: | ||||
|  | ||||
|     Args: | ||||
|         type (`str`): | ||||
|             The device type (e.g., "cuda", "mps", "cpu"). | ||||
|             The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu"). | ||||
|         properties ([`CUDAProperties`], *optional*): | ||||
|             Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices. | ||||
|  | ||||
| @ -106,6 +106,12 @@ class Device: | ||||
|  | ||||
|         # MPS device for Apple Silicon | ||||
|         mps_device = Device(type="mps") | ||||
|  | ||||
|         # XPU device (e.g., Intel(R) Data Center GPU Max 1550) | ||||
|         xpu_device = Device(type="xpu") | ||||
|  | ||||
|         # NPU device (Huawei Ascend) | ||||
|         npu_device = Device(type="npu") | ||||
|         ``` | ||||
|     """ | ||||
|  | ||||
| @ -125,6 +131,10 @@ class Device: | ||||
|             return _ROCMRepos() | ||||
|         elif self.type == "mps": | ||||
|             return _MPSRepos() | ||||
|         elif self.type == "xpu": | ||||
|             return _XPURepos() | ||||
|         elif self.type == "npu": | ||||
|             return _NPURepos() | ||||
|         else: | ||||
|             raise ValueError(f"Unknown device type: {self.type}") | ||||
|  | ||||
| @ -311,7 +321,7 @@ class LayerRepository: | ||||
|         return hash((self.layer_name, self._repo_id, self._revision, self._version)) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`" | ||||
|         return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`" | ||||
|  | ||||
|  | ||||
| class LocalLayerRepository: | ||||
| @ -367,7 +377,7 @@ class LocalLayerRepository: | ||||
|         return hash((self.layer_name, self._repo_path, self._package_name)) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"`{self._repo_path}` (package: {self._package_name}) for layer `{self.layer_name}`" | ||||
|         return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`" | ||||
|  | ||||
|  | ||||
| class LockedLayerRepository: | ||||
| @ -422,7 +432,7 @@ class LockedLayerRepository: | ||||
|         return hash((self.layer_name, self._repo_id)) | ||||
|  | ||||
|     def __str__(self) -> str: | ||||
|         return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`" | ||||
|         return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`" | ||||
|  | ||||
|  | ||||
| _CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {} | ||||
| @ -447,6 +457,46 @@ class _DeviceRepos(ABC): | ||||
|         ... | ||||
|  | ||||
|  | ||||
class _XPURepos(_DeviceRepos):
    """Repositories registered for XPU devices, stored as one mode-to-repository mapping."""

    _repos: Dict[Mode, LayerRepositoryProtocol]

    def __init__(self):
        super().__init__()
        # Starts empty; `insert` replaces the whole mapping.
        self._repos = {}

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        """The mapping most recently stored by `insert` (empty before the first insert)."""
        return self._repos

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        """Replace the stored mapping with `repos`.

        Raises:
            ValueError: If `device.type` is not `"xpu"`.
        """
        if device.type == "xpu":
            self._repos = repos
            return

        raise ValueError(f"Device type must be 'xpu', got {device.type}")
|  | ||||
|  | ||||
class _NPURepos(_DeviceRepos):
    """Repositories registered for NPU devices, stored as one mode-to-repository mapping."""

    _repos: Dict[Mode, LayerRepositoryProtocol]

    def __init__(self):
        super().__init__()
        # Starts empty; `insert` replaces the whole mapping.
        self._repos = {}

    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        """The mapping most recently stored by `insert` (empty before the first insert)."""
        return self._repos

    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        """Replace the stored mapping with `repos`.

        Raises:
            ValueError: If `device.type` is not `"npu"`.
        """
        if device.type == "npu":
            self._repos = repos
            return

        raise ValueError(f"Device type must be 'npu', got {device.type}")
|  | ||||
|  | ||||
| class _MPSRepos(_DeviceRepos): | ||||
|     _repos: Dict[Mode, LayerRepositoryProtocol] | ||||
|  | ||||
| @ -531,7 +581,7 @@ class _ROCMRepos(_DeviceRepos): | ||||
|  | ||||
| def _validate_device_type(device_type: str) -> None: | ||||
|     """Validate that the device type is supported.""" | ||||
|     supported_devices = {"cuda", "rocm", "mps", "cpu"} | ||||
|     supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"} | ||||
|     if device_type not in supported_devices: | ||||
|         raise ValueError( | ||||
|             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}" | ||||
| @ -578,7 +628,7 @@ def use_kernel_mapping( | ||||
|  | ||||
|         from kernels import use_kernel_forward_from_hub | ||||
|         from kernels import use_kernel_mapping, LayerRepository, Device | ||||
|         from kernels import kernelize | ||||
|         from kernels import Mode, kernelize | ||||
|  | ||||
|         # Define a mapping | ||||
|         mapping = { | ||||
| @ -601,7 +651,7 @@ def use_kernel_mapping( | ||||
|         # Use the mapping for the duration of the context. | ||||
|         with use_kernel_mapping(mapping): | ||||
|             # kernelize uses the temporary mapping | ||||
|             model = kernelize(model, device="cuda") | ||||
|             model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda") | ||||
|  | ||||
|         # Outside the context, original mappings are restored | ||||
|         ``` | ||||
| @ -772,7 +822,7 @@ def _select_repository( | ||||
| def kernelize( | ||||
|     model: "nn.Module", | ||||
|     *, | ||||
|     mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE, | ||||
|     mode: Mode, | ||||
|     device: Optional[Union[str, "torch.device"]] = None, | ||||
|     use_fallback: bool = True, | ||||
| ): | ||||
| @ -785,11 +835,11 @@ def kernelize( | ||||
|     Args: | ||||
|         model (`nn.Module`): | ||||
|             The PyTorch model to kernelize. | ||||
|         mode ([`Mode`], *optional*, defaults to `Mode.TRAINING | Mode.TORCH_COMPILE`): | ||||
|             The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE` | ||||
|             kernelizes the model for training with `torch.compile`. | ||||
|         mode ([`Mode`]): The mode that the kernel is going to be used in. For example, | ||||
|             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with | ||||
|             `torch.compile`. | ||||
|         device (`Union[str, torch.device]`, *optional*): | ||||
|             The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu". | ||||
|             The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu". | ||||
|             The device type will be inferred from the model parameters when not provided. | ||||
|         use_fallback (`bool`, *optional*, defaults to `True`): | ||||
|             Whether to use the original forward method of modules when no compatible kernel could be found. | ||||
| @ -813,7 +863,7 @@ def kernelize( | ||||
|                 return F.silu(x[..., :d]) * x[..., d:] | ||||
|  | ||||
|         mapping = { | ||||
|             "LayerNorm": { | ||||
|             "SiluAndMul": { | ||||
|                 "cuda": LayerRepository( | ||||
|                     repo_id="kernels-community/activation", | ||||
|                     layer_name="SiluAndMul", | ||||
| @ -829,7 +879,7 @@ def kernelize( | ||||
|         ) | ||||
|  | ||||
|         # Kernelize for inference | ||||
|         kernelized_model = kernelize(model) | ||||
|         kernelized_model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
|         ``` | ||||
|     """ | ||||
|  | ||||
| @ -954,7 +1004,8 @@ def use_kernel_forward_from_hub(layer_name: str): | ||||
|         import torch | ||||
|         import torch.nn as nn | ||||
|  | ||||
|         from kernels import use_kernel_forward_from_hub, kernelize | ||||
|         from kernels import use_kernel_forward_from_hub | ||||
|         from kernels import Mode, kernelize | ||||
|  | ||||
|         @use_kernel_forward_from_hub("MyCustomLayer") | ||||
|         class MyCustomLayer(nn.Module): | ||||
| @ -969,7 +1020,7 @@ def use_kernel_forward_from_hub(layer_name: str): | ||||
|         model = MyCustomLayer(768) | ||||
|  | ||||
|         # The layer can now be kernelized: | ||||
|         # model = kernelize(model, device="cuda") | ||||
|         # model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda") | ||||
|         ``` | ||||
|     """ | ||||
|  | ||||
| @ -994,7 +1045,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]: | ||||
|     return layer | ||||
|  | ||||
|  | ||||
| def _validate_layer(*, check_cls, cls): | ||||
| def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol): | ||||
|     import torch.nn as nn | ||||
|  | ||||
|     # The layer must at least have the following properties: (1) it | ||||
| @ -1003,12 +1054,12 @@ def _validate_layer(*, check_cls, cls): | ||||
|     # methods. | ||||
|  | ||||
|     if not issubclass(cls, nn.Module): | ||||
|         raise TypeError(f"Layer `{cls}` is not a Torch layer.") | ||||
|         raise TypeError(f"Layer `{cls.__name__}` is not a Torch layer.") | ||||
|  | ||||
|     # We verify statelessness by checking that the layer does not have its own | ||||
|     # constructor (since the constructor could add member variables)... | ||||
|     if cls.__init__ is not nn.Module.__init__: | ||||
|         raise TypeError("Layer must not override nn.Module constructor.") | ||||
|         raise TypeError(f"{repo} must not override nn.Module constructor.") | ||||
|  | ||||
|     # ... or predefined member variables. | ||||
|     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)} | ||||
| @ -1016,7 +1067,9 @@ def _validate_layer(*, check_cls, cls): | ||||
|     difference = cls_members - torch_module_members | ||||
|     # verify if : difference ⊄ {"can_torch_compile", "has_backward"} | ||||
|     if not difference <= {"can_torch_compile", "has_backward"}: | ||||
|         raise TypeError("Layer must not contain additional members.") | ||||
|         raise TypeError( | ||||
|             f"{repo} must not contain additional members compared to `{check_cls.__name__}`." | ||||
|         ) | ||||
|  | ||||
|     # Check whether the forward signatures are similar. | ||||
|     params = inspect.signature(cls.forward).parameters | ||||
| @ -1024,13 +1077,13 @@ def _validate_layer(*, check_cls, cls): | ||||
|  | ||||
|     if len(params) != len(ref_params): | ||||
|         raise TypeError( | ||||
|             "Forward signature does not match: different number of arguments." | ||||
|             f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments." | ||||
|         ) | ||||
|  | ||||
|     for param, ref_param in zip(params.values(), ref_params.values()): | ||||
|         if param.kind != ref_param.kind: | ||||
|             raise TypeError( | ||||
|                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" | ||||
|                 f"Forward signature of {repo} does not match `{check_cls.__name__}`: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" | ||||
|             ) | ||||
|  | ||||
|  | ||||
| @ -1147,7 +1200,7 @@ def _get_layer_memoize( | ||||
|         return layer | ||||
|  | ||||
|     layer = _get_kernel_layer(repo) | ||||
|     _validate_layer(check_cls=module_class, cls=layer) | ||||
|     _validate_layer(check_cls=module_class, cls=layer, repo=repo) | ||||
|     _CACHED_LAYER[repo] = layer | ||||
|  | ||||
|     return layer | ||||
|  | ||||
| @ -35,6 +35,14 @@ def _get_cache_dir() -> Optional[str]: | ||||
| CACHE_DIR: Optional[str] = _get_cache_dir() | ||||
|  | ||||
|  | ||||
| def _get_privateuse_backend_name() -> Optional[str]: | ||||
|     import torch | ||||
|  | ||||
|     if hasattr(torch._C, "_get_privateuse1_backend_name"): | ||||
|         return torch._C._get_privateuse1_backend_name() | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def build_variant() -> str: | ||||
|     import torch | ||||
|  | ||||
| @ -46,11 +54,17 @@ def build_variant() -> str: | ||||
|         compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}" | ||||
|     elif torch.backends.mps.is_available(): | ||||
|         compute_framework = "metal" | ||||
|     elif hasattr(torch, "xpu") and torch.xpu.is_available(): | ||||
|         compute_framework = "xpu" | ||||
|     elif torch.version.xpu is not None: | ||||
|         version = torch.version.xpu | ||||
|         compute_framework = f"xpu{version[0:4]}{version[5:6]}" | ||||
|     elif _get_privateuse_backend_name() == "npu": | ||||
|         from torch_npu.utils.collect_env import get_cann_version  # type: ignore[import-not-found] | ||||
|  | ||||
|         cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] | ||||
|         compute_framework = f"cann{cann_major}{cann_minor}" | ||||
|     else: | ||||
|         raise AssertionError( | ||||
|             "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled." | ||||
|             "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." | ||||
|         ) | ||||
|  | ||||
|     torch_version = parse(torch.__version__) | ||||
| @ -71,24 +85,6 @@ def universal_build_variant() -> str: | ||||
|     return "torch-universal" | ||||
|  | ||||
|  | ||||
| # Metaclass to allow overriding the `__repr__` method for kernel modules. | ||||
| class _KernelModuleMeta(type): | ||||
|     def __repr__(self): | ||||
|         return "<class 'kernel_module'>" | ||||
|  | ||||
|  | ||||
| # Custom module type to identify dynamically loaded kernel modules. | ||||
| # Using a subclass lets us distinguish these from regular imports. | ||||
| class _KernelModuleType(ModuleType, metaclass=_KernelModuleMeta): | ||||
|     """Marker class for modules loaded dynamically from a path.""" | ||||
|  | ||||
|     module_name: str | ||||
|     is_kernel: bool = True | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return f"<kernel_module '{self.module_name}' from '{self.__file__}'>" | ||||
|  | ||||
|  | ||||
| def import_from_path(module_name: str, file_path: Path) -> ModuleType: | ||||
|     # We cannot use the module name as-is, after adding it to `sys.modules`, | ||||
|     # it would also be used for other imports. So, we make a module name that | ||||
| @ -102,9 +98,6 @@ def import_from_path(module_name: str, file_path: Path) -> ModuleType: | ||||
|     module = importlib.util.module_from_spec(spec) | ||||
|     if module is None: | ||||
|         raise ImportError(f"Cannot load module {module_name} from spec") | ||||
|     module.__class__ = _KernelModuleType | ||||
|     assert isinstance(module, _KernelModuleType)  # for mypy type checking | ||||
|     module.module_name = module_name | ||||
|     sys.modules[module_name] = module | ||||
|     spec.loader.exec_module(module)  # type: ignore | ||||
|     return module | ||||
| @ -269,8 +262,24 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: | ||||
|     Returns: | ||||
|         `ModuleType`: The imported kernel module. | ||||
|     """ | ||||
|     package_name, package_path = _load_kernel_from_path(repo_path, package_name) | ||||
|     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||
|     variant = build_variant() | ||||
|     universal_variant = universal_build_variant() | ||||
|  | ||||
|     # Presume we were given the top level path of the kernel repository. | ||||
|     for base_path in [repo_path, repo_path / "build"]: | ||||
|         # Prefer the universal variant if it exists. | ||||
|         for v in [universal_variant, variant]: | ||||
|             package_path = base_path / v / package_name / "__init__.py" | ||||
|             if package_path.exists(): | ||||
|                 return import_from_path(package_name, package_path) | ||||
|  | ||||
|     # If we didn't find the package in the repo we may have an explicit | ||||
|     # package path. | ||||
|     package_path = repo_path / package_name / "__init__.py" | ||||
|     if package_path.exists(): | ||||
|         return import_from_path(package_name, package_path) | ||||
|  | ||||
|     raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}") | ||||
|  | ||||
|  | ||||
| def has_kernel( | ||||
|  | ||||
| @ -3,6 +3,8 @@ import sys | ||||
| import pytest | ||||
| import torch | ||||
|  | ||||
| from kernels.utils import _get_privateuse_backend_name | ||||
|  | ||||
| has_cuda = ( | ||||
|     hasattr(torch.version, "cuda") | ||||
|     and torch.version.cuda is not None | ||||
| @ -13,6 +15,20 @@ has_rocm = ( | ||||
|     and torch.version.hip is not None | ||||
|     and torch.cuda.device_count() > 0 | ||||
| ) | ||||
| has_xpu = ( | ||||
|     hasattr(torch.version, "xpu") | ||||
|     and torch.version.xpu is not None | ||||
|     and torch.xpu.device_count() > 0 | ||||
| ) | ||||
| has_npu = _get_privateuse_backend_name() == "npu" | ||||
|  | ||||
|  | ||||
def pytest_addoption(parser):
    """Register the `--token` command-line flag on the given option parser."""
    token_flag = {
        "action": "store_true",
        "help": "run tests that require a token with write permissions",
    }
    parser.addoption("--token", **token_flag)
|  | ||||
|  | ||||
| def pytest_runtest_setup(item): | ||||
| @ -22,3 +38,9 @@ def pytest_runtest_setup(item): | ||||
|         pytest.skip("skipping ROCm-only test on host without ROCm") | ||||
|     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"): | ||||
|         pytest.skip("skipping macOS-only test on non-macOS platform") | ||||
|     if "xpu_only" in item.keywords and not has_xpu: | ||||
|         pytest.skip("skipping XPU-only test on host without XPU") | ||||
|     if "npu_only" in item.keywords and not has_npu: | ||||
|         pytest.skip("skipping NPU-only test on host without NPU") | ||||
|     if "token" in item.keywords and not item.config.getoption("--token"): | ||||
|         pytest.skip("need --token option to run this test") | ||||
|  | ||||
| @ -10,10 +10,16 @@ def kernel(): | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def local_kernel(): | ||||
| def local_kernel_path(): | ||||
|     package_name, path = install_kernel("kernels-community/activation", "main") | ||||
|     # Path is the build variant path (build/torch-<...>), so the grandparent | ||||
|     # is the kernel repository path. | ||||
|     return package_name, path | ||||
|  | ||||
|  | ||||
| @pytest.fixture | ||||
| def local_kernel(local_kernel_path): | ||||
|     package_name, path = local_kernel_path | ||||
|     return get_local_kernel(path.parent.parent, package_name) | ||||
|  | ||||
|  | ||||
| @ -66,6 +72,39 @@ def test_local_kernel(local_kernel, device): | ||||
|     assert torch.allclose(y, expected) | ||||
|  | ||||
|  | ||||
@pytest.mark.cuda_only
def test_local_kernel_path_types(local_kernel_path, device):
    """`get_local_kernel` accepts the repo root, its `build` directory, or a build variant directory."""
    package_name, path = local_kernel_path
    repo_root = path.parent.parent

    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
    expected = torch.tensor(
        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
        device=device,
        dtype=torch.float16,
    )

    candidate_paths = [
        # Top-level repo path
        # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
        repo_root,
        # Build directory path
        # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
        repo_root / "build",
        # Explicit package path
        # ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
        path,
    ]

    # Every accepted path form must load a kernel that produces the same output.
    for candidate in candidate_paths:
        kernel = get_local_kernel(candidate, package_name)
        y = torch.empty_like(x)
        kernel.gelu_fast(y, x)
        assert torch.allclose(y, expected)
|  | ||||
|  | ||||
| @pytest.mark.darwin_only | ||||
| @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) | ||||
| def test_relu_metal(metal_kernel, dtype): | ||||
|  | ||||
| @ -35,6 +35,7 @@ def test_load_locked(): | ||||
|     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| def test_layer_locked(): | ||||
|     project_dir = Path(__file__).parent / "layer_locking" | ||||
|  | ||||
|  | ||||
							
								
								
									
										115
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								tests/test_kernel_upload.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,115 @@ | ||||
| import logging | ||||
| import os | ||||
| import re | ||||
| import tempfile | ||||
| from dataclasses import dataclass | ||||
| from pathlib import Path | ||||
| from typing import List | ||||
|  | ||||
| import pytest | ||||
| from huggingface_hub import delete_repo, model_info | ||||
|  | ||||
| from kernels.cli import upload_kernels | ||||
|  | ||||
| REPO_ID = "valid_org/kernels-upload-test" | ||||
|  | ||||
|  | ||||
| PY_CONTENT = """\ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| def main(): | ||||
|     print("Hello from torch-universal!") | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
| """ | ||||
|  | ||||
|  | ||||
@dataclass
class UploadArgs:
    """Stand-in for the parsed CLI arguments that `upload_kernels` reads."""

    # Directory that contains the `build` directory to upload.
    kernel_dir: str
    # Hub repository the build is uploaded to.
    repo_id: str
    # Passed to repository creation as the `private` flag.
    private: bool = False
|  | ||||
|  | ||||
def next_filename(path: Path) -> Path:
    """
    Return `path` with the number directly before its `.py` suffix incremented.

    For example, foo_2050.py becomes foo_2051.py. The new number is
    zero-padded to the width of the original one.

    Raises:
        ValueError: If the file name does not end in digits followed by `.py`.
    """
    match = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
    if match is None:
        raise ValueError(
            f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
        )

    stem, digits, extension = match.group(1), match.group(2), match.group(3)
    bumped = f"{int(digits) + 1:0{len(digits)}d}"
    return path.with_name(stem + bumped + extension)
|  | ||||
|  | ||||
def get_filename_to_change(repo_filenames):
    """Return the base name of the first `.py` path in `repo_filenames` that contains "foo".

    Raises:
        AssertionError: If no such path exists.
    """
    filename_to_change = next(
        (
            os.path.basename(f)
            for f in repo_filenames
            if "foo" in f and f.endswith(".py")
        ),
        None,
    )
    # Fail with a readable message instead of an unbound-variable error when
    # the repository does not contain the expected file.
    if not filename_to_change:
        raise AssertionError(f"No `foo*.py` file found in {repo_filenames!r}")
    return filename_to_change
|  | ||||
|  | ||||
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
    """Return the paths of all files listed for the Hub repository `repo_id`.

    Raises:
        ValueError: If the repository info carries no file listing
            (`siblings` is `None`).
        Exception: Any error raised by `model_info` is logged and re-raised.
    """
    try:
        repo_info = model_info(repo_id=repo_id, files_metadata=True)
    except Exception as e:
        # Log for context, then re-raise: returning `None` here would only
        # surface later as an unrelated `TypeError` in the caller.
        logging.error(f"Error connecting to the Hub: {e}.")
        raise

    repo_siblings = repo_info.siblings
    if repo_siblings is None:
        raise ValueError("No repo siblings found.")
    return [f.rfilename for f in repo_siblings]
|  | ||||
|  | ||||
@pytest.mark.token
@pytest.mark.is_staging_test
def test_kernel_upload_works_as_expected():
    """Uploading a kernel directory makes its build file appear in the Hub repository."""
    with tempfile.TemporaryDirectory() as tmpdir:
        build_dir = Path(tmpdir) / "build" / "torch-universal" / "upload_test"
        build_dir.mkdir(parents=True, exist_ok=True)
        script_path = build_dir / "foo.py"
        script_path.write_text(PY_CONTENT)
        upload_kernels(UploadArgs(tmpdir, REPO_ID, False))

    try:
        repo_filenames = get_filenames_from_a_repo(REPO_ID)
        # Check that the uploaded file name occurs in one of the repository
        # paths; testing the truthiness of the name alone would always pass.
        assert any(
            script_path.name in f for f in repo_filenames
        ), f"{repo_filenames=}"
    finally:
        # Always remove the test repository, also when the assertion fails.
        delete_repo(repo_id=REPO_ID)
|  | ||||
|  | ||||
@pytest.mark.token
@pytest.mark.is_staging_test
def test_kernel_upload_deletes_as_expected():
    """Re-uploading a build variant removes files that are no longer part of it."""

    def upload_single_file(filename):
        # Upload a fresh kernel directory whose only file is
        # build/torch-universal/upload_test/<filename>.
        with tempfile.TemporaryDirectory() as tmpdir:
            build_dir = Path(tmpdir) / "build" / "torch-universal" / "upload_test"
            build_dir.mkdir(parents=True, exist_ok=True)
            (build_dir / filename).write_text(PY_CONTENT)
            upload_kernels(UploadArgs(tmpdir, REPO_ID, False))

    upload_single_file("foo_2025.py")

    try:
        repo_filenames = get_filenames_from_a_repo(REPO_ID)
        filename_to_change = get_filename_to_change(repo_filenames)

        # Second upload: same build variant, but the file has a new name.
        changed_filename = next_filename(Path(filename_to_change))
        upload_single_file(changed_filename)

        repo_filenames = get_filenames_from_a_repo(REPO_ID)
        assert any(
            str(changed_filename) in k for k in repo_filenames
        ), f"{repo_filenames=}"
        assert not any(
            str(filename_to_change) in k for k in repo_filenames
        ), f"{repo_filenames=}"
    finally:
        # Always remove the test repository so a failed run does not leave
        # stale files behind for the next one.
        delete_repo(repo_id=REPO_ID)
| @ -21,14 +21,21 @@ from kernels.layer import ( | ||||
|     _KERNEL_MAPPING, | ||||
|     _validate_layer, | ||||
| ) | ||||
| from kernels.utils import install_kernel | ||||
| from kernels.utils import ( | ||||
|     _get_privateuse_backend_name, | ||||
|     install_kernel, | ||||
| ) | ||||
|  | ||||
| kernel_layer_mapping = { | ||||
|     "SiluAndMul": { | ||||
|         Device(type="cuda"): LayerRepository( | ||||
|             repo_id="kernels-community/activation", | ||||
|             layer_name="SiluAndMul", | ||||
|         ) | ||||
|         ), | ||||
|         "npu": LayerRepository( | ||||
|             repo_id="kernels-ext-npu/SwiGlu", | ||||
|             layer_name="SwiGlu", | ||||
|         ), | ||||
|     }, | ||||
|     "SiluAndMulNoCompile": { | ||||
|         "cuda": LayerRepository( | ||||
| @ -46,11 +53,37 @@ kernel_layer_mapping = { | ||||
|             layer_name="SiluAndMul", | ||||
|         ) | ||||
|     }, | ||||
|     "LigerRMSNorm": { | ||||
|         "xpu": LayerRepository( | ||||
|             repo_id="kernels-community/liger_kernels", | ||||
|             layer_name="LigerRMSNorm",  # Triton | ||||
|         ) | ||||
|     }, | ||||
| } | ||||
|  | ||||
| register_kernel_mapping(kernel_layer_mapping) | ||||
|  | ||||
|  | ||||
class RMSNorm(nn.Module):
    """Pure-PyTorch RMS normalization that counts how often its forward runs."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        # Incremented by `forward`; stays at 0 when a replacement forward
        # (the hub kernel) is used instead.
        self.n_calls = 0
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor):
        """Scale `x` by the reciprocal root mean square over its last dimension, then by `weight`."""
        self.n_calls += 1
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = x * inv_rms
        return normalized * self.weight
|  | ||||
|  | ||||
@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNormWithKernel(RMSNorm):
    """`RMSNorm` registered under the layer name "LigerRMSNorm" so that it can be kernelized."""
|  | ||||
|  | ||||
| class SiluAndMul(nn.Module): | ||||
|     def __init__(self): | ||||
|         super().__init__() | ||||
| @ -90,6 +123,18 @@ class TorchLinearWithCounter(nn.Linear): | ||||
|         return super().forward(input) | ||||
|  | ||||
|  | ||||
@pytest.fixture
def device():
    """Device type to run on: "cuda", "xpu" or "npu", probed in that order.

    Skips the requesting test when none of them is detected.
    """
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if _get_privateuse_backend_name() == "npu":
        return "npu"

    pytest.skip("No CUDA, NPU or XPU")
|  | ||||
|  | ||||
| def test_arg_kinds(): | ||||
|     @use_kernel_forward_from_hub("ArgKind") | ||||
|     class ArgKind(nn.Module): | ||||
| @ -110,24 +155,20 @@ def test_arg_kinds(): | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) | ||||
| @pytest.mark.parametrize("device", ["cuda", "cpu"]) | ||||
| def test_hub_forward(cls, device): | ||||
| def test_hub_forward(cls): | ||||
|     torch.random.manual_seed(0) | ||||
|  | ||||
|     silu_and_mul = SiluAndMul() | ||||
|     X = torch.randn((32, 64), device=device) | ||||
|     X = torch.randn((32, 64), device="cuda") | ||||
|     Y = silu_and_mul(X) | ||||
|  | ||||
|     silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE) | ||||
|     silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE) | ||||
|     Y_kernel = silu_and_mul_with_kernel(X) | ||||
|  | ||||
|     torch.testing.assert_close(Y_kernel, Y) | ||||
|  | ||||
|     assert silu_and_mul.n_calls == 1 | ||||
|     if device == "cuda": | ||||
|         assert silu_and_mul_with_kernel.n_calls == 0 | ||||
|     else: | ||||
|         assert silu_and_mul_with_kernel.n_calls == 1 | ||||
|     assert silu_and_mul_with_kernel.n_calls == 0 | ||||
|  | ||||
|  | ||||
| @pytest.mark.rocm_only | ||||
| @ -151,6 +192,54 @@ def test_hub_forward_rocm(): | ||||
|     assert silu_and_mul_with_kernel.n_calls in [0, 1] | ||||
|  | ||||
|  | ||||
@pytest.mark.xpu_only
def test_hub_forward_xpu():
    """A kernelized RMSNorm on XPU matches the PyTorch reference without running the original forward."""
    torch.manual_seed(0)

    target = "xpu"
    hidden_size = 1024
    weight = torch.ones(hidden_size, device=target)
    reference = RMSNorm(weight).to(target)
    X = torch.randn(4, 16, hidden_size, device=target, dtype=torch.float32)
    Y = reference(X)

    kernelized = kernelize(
        RMSNormWithKernel(weight), mode=Mode.INFERENCE, device=target
    )
    Y_kernel = kernelized(X)

    torch.testing.assert_close(Y_kernel, Y)

    # `n_calls` only counts runs of the pure-PyTorch forward, so it must stay
    # at zero for the kernelized layer.
    assert reference.n_calls == 1
    assert kernelized.n_calls == 0
|  | ||||
|  | ||||
@pytest.mark.npu_only
def test_hub_forward_npu():
    """A kernelized SiluAndMul on NPU matches the PyTorch reference without running the original forward."""
    torch.manual_seed(0)

    target = "npu"
    reference = SiluAndMul()
    X = torch.randn((32, 64), device=target)
    Y = reference(X)

    kernelized = kernelize(
        SiluAndMulWithKernel(), device=target, mode=Mode.INFERENCE
    )
    Y_kernel = kernelized(X)

    torch.testing.assert_close(Y_kernel, Y)

    # `n_calls` only counts runs of the pure-PyTorch forward, so it must stay
    # at zero for the kernelized layer.
    assert reference.n_calls == 1
    assert kernelized.n_calls == 0
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
|     hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(), | ||||
|     reason="Skip on xpu devices", | ||||
| ) | ||||
| @pytest.mark.skipif( | ||||
|     _get_privateuse_backend_name() == "npu", | ||||
|     reason="Skip on npu devices", | ||||
| ) | ||||
| def test_rocm_kernel_mapping(): | ||||
|     """Test that ROCm shorthand device mapping works correctly.""" | ||||
|     kernel_layer_mapping = { | ||||
| @ -238,16 +327,16 @@ def test_layer_fallback_works(): | ||||
|     kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE) | ||||
|  | ||||
|  | ||||
| def test_local_layer_repo(): | ||||
| def test_local_layer_repo(device): | ||||
|     # Fetch a kernel to the local cache. | ||||
|     package_name, path = install_kernel("kernels-test/backward-marker-test", "main") | ||||
|  | ||||
|     linear = TorchLinearWithCounter(32, 32).to("cuda") | ||||
|     linear = TorchLinearWithCounter(32, 32).to(device) | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Linear": { | ||||
|                 "cuda": LocalLayerRepository( | ||||
|                 device: LocalLayerRepository( | ||||
|                     # install_kernel will give the fully-resolved path. | ||||
|                     repo_path=path.parent.parent, | ||||
|                     package_name=package_name, | ||||
| @ -259,7 +348,7 @@ def test_local_layer_repo(): | ||||
|     ): | ||||
|         kernelize(linear, mode=Mode.INFERENCE) | ||||
|  | ||||
|     X = torch.randn(10, 32, device="cuda") | ||||
|     X = torch.randn(10, 32, device=device) | ||||
|     linear(X) | ||||
|     assert linear.n_calls == 0 | ||||
|  | ||||
| @ -327,6 +416,7 @@ def test_mapping_contexts(): | ||||
|         "SiluAndMul", | ||||
|         "SiluAndMulStringDevice", | ||||
|         "SiluAndMulNoCompile", | ||||
|         "LigerRMSNorm", | ||||
|     } | ||||
|  | ||||
|     extra_mapping1 = { | ||||
| @ -344,6 +434,7 @@ def test_mapping_contexts(): | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "SiluAndMulNoCompile", | ||||
|             "LigerRMSNorm", | ||||
|             "TestKernel", | ||||
|         } | ||||
|  | ||||
| @ -362,6 +453,7 @@ def test_mapping_contexts(): | ||||
|                 "SiluAndMul", | ||||
|                 "SiluAndMulStringDevice", | ||||
|                 "SiluAndMulNoCompile", | ||||
|                 "LigerRMSNorm", | ||||
|                 "TestKernel", | ||||
|             } | ||||
|             assert ( | ||||
| @ -375,6 +467,7 @@ def test_mapping_contexts(): | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "SiluAndMulNoCompile", | ||||
|             "LigerRMSNorm", | ||||
|             "TestKernel", | ||||
|         } | ||||
|         assert ( | ||||
| @ -397,6 +490,7 @@ def test_mapping_contexts(): | ||||
|             "SiluAndMul", | ||||
|             "SiluAndMulStringDevice", | ||||
|             "SiluAndMulNoCompile", | ||||
|             "LigerRMSNorm", | ||||
|             "TestKernel", | ||||
|         } | ||||
|         assert ( | ||||
| @ -408,6 +502,7 @@ def test_mapping_contexts(): | ||||
|         "SiluAndMul", | ||||
|         "SiluAndMulStringDevice", | ||||
|         "SiluAndMulNoCompile", | ||||
|         "LigerRMSNorm", | ||||
|     } | ||||
|  | ||||
|  | ||||
| @ -417,26 +512,43 @@ def test_validate_kernel_layer(): | ||||
|             super().__init__(*args, **kwargs) | ||||
|             self.foo = 42 | ||||
|  | ||||
|     with pytest.raises(TypeError, match="not override"): | ||||
|         _validate_layer(cls=BadLayer, check_cls=SiluAndMul) | ||||
|     def stub_repo(layer): | ||||
|         return LayerRepository( | ||||
|             repo_id="kernels-test/nonexisting", layer_name=layer.__name__ | ||||
|         ) | ||||
|  | ||||
|     with pytest.raises( | ||||
|         TypeError, | ||||
|         match="`kernels-test/nonexisting`.*layer `BadLayer` must not override", | ||||
|     ): | ||||
|         _validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer)) | ||||
|  | ||||
|     class BadLayer2(nn.Module): | ||||
|         foo: int = 42 | ||||
|  | ||||
|     with pytest.raises(TypeError, match="not contain additional members"): | ||||
|         _validate_layer(cls=BadLayer2, check_cls=SiluAndMul) | ||||
|     with pytest.raises( | ||||
|         TypeError, | ||||
|         match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul", | ||||
|     ): | ||||
|         _validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2)) | ||||
|  | ||||
|     class BadLayer3(nn.Module): | ||||
|         def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ... | ||||
|  | ||||
|     with pytest.raises(TypeError, match="different number of arguments"): | ||||
|         _validate_layer(cls=BadLayer3, check_cls=SiluAndMul) | ||||
|     with pytest.raises( | ||||
|         TypeError, | ||||
|         match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments", | ||||
|     ): | ||||
|         _validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3)) | ||||
|  | ||||
|     class BadLayer4(nn.Module): | ||||
|         def forward(self, *, x: torch.Tensor) -> torch.Tensor: ... | ||||
|  | ||||
|     with pytest.raises(TypeError, match="different kind of arguments"): | ||||
|         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul) | ||||
|     with pytest.raises( | ||||
|         TypeError, | ||||
|         match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments", | ||||
|     ): | ||||
|         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4)) | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @ -488,11 +600,6 @@ def test_kernel_modes(): | ||||
|         linear(X) | ||||
|         assert linear.n_calls == 0 | ||||
|  | ||||
|         # Same as previous, since TRAINING | TORCH_COMPILE is the default. | ||||
|         kernelize(linear) | ||||
|         linear(X) | ||||
|         assert linear.n_calls == 0 | ||||
|  | ||||
|     # Case 2: register a kernel just for training. If no base kernel | ||||
|     #         layer is registered, we fall back to the original layer. | ||||
|     with use_kernel_mapping( | ||||
| @ -522,12 +629,6 @@ def test_kernel_modes(): | ||||
|         # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original. | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|         # Same as previous, since TRAINING | TORCH_COMPILE is the default. | ||||
|         kernelize(linear) | ||||
|         linear(X) | ||||
|         # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original. | ||||
|         assert linear.n_calls == 2 | ||||
|  | ||||
|     # Case 3: register a kernel just for training and one for fallback. | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
| @ -549,23 +650,17 @@ def test_kernel_modes(): | ||||
|         X = torch.randn(10, 32, device="cuda") | ||||
|         linear(X) | ||||
|         # Falls back to TRAINING. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|         kernelize(linear, mode=Mode.TRAINING) | ||||
|         linear(X) | ||||
|         # Falls back to the TRAINING kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
|         linear(X) | ||||
|         # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|  | ||||
|         # Same as previous, since TRAINING | TORCH_COMPILE is the default. | ||||
|         kernelize(linear) | ||||
|         linear(X) | ||||
|         # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|     # Case 4: register a kernel with two preferences. | ||||
|     with use_kernel_mapping( | ||||
| @ -585,22 +680,17 @@ def test_kernel_modes(): | ||||
|         X = torch.randn(10, 32, device="cuda") | ||||
|         linear(X) | ||||
|         # Falls back to the TRAINING | TORCH_COMPILE kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|         kernelize(linear, mode=Mode.TRAINING) | ||||
|         linear(X) | ||||
|         # TRAINING can fall back to TRAINING | TORCH_COMPILE kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE) | ||||
|         linear(X) | ||||
|         # Uses TRAINING | TORCH_COMPILE kernel. | ||||
|         assert linear.n_calls == 2 | ||||
|  | ||||
|         kernelize(linear) | ||||
|         linear(X) | ||||
|         # Same as previous, since TRAINING | TORCH_COMPILE is the default. | ||||
|         assert linear.n_calls == 2 | ||||
|         assert linear.n_calls == 1 | ||||
|  | ||||
|  | ||||
| @pytest.mark.cuda_only | ||||
| @ -949,7 +1039,7 @@ def test_kernel_modes_cross_fallback(): | ||||
|         assert linear.n_calls == 2 | ||||
|  | ||||
|  | ||||
| def test_layer_versions(): | ||||
| def test_layer_versions(device): | ||||
|     @use_kernel_forward_from_hub("Version") | ||||
|     class Version(nn.Module): | ||||
|         def forward(self) -> str: | ||||
| @ -960,20 +1050,20 @@ def test_layer_versions(): | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                 Device(type=device): LayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                 ) | ||||
|             } | ||||
|         } | ||||
|     ): | ||||
|         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|         version = kernelize(version, device=device, mode=Mode.INFERENCE) | ||||
|         assert version() == "0.2.0" | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                 Device(type=device): LayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                     version="<1.0.0", | ||||
| @ -981,13 +1071,13 @@ def test_layer_versions(): | ||||
|             } | ||||
|         } | ||||
|     ): | ||||
|         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|         version = kernelize(version, device=device, mode=Mode.INFERENCE) | ||||
|         assert version() == "0.2.0" | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                 Device(type=device): LayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                     version="<0.2.0", | ||||
| @ -995,13 +1085,13 @@ def test_layer_versions(): | ||||
|             } | ||||
|         } | ||||
|     ): | ||||
|         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|         version = kernelize(version, device=device, mode=Mode.INFERENCE) | ||||
|         assert version() == "0.1.1" | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                 Device(type=device): LayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                     version=">0.1.0,<0.2.0", | ||||
| @ -1009,13 +1099,13 @@ def test_layer_versions(): | ||||
|             } | ||||
|         } | ||||
|     ): | ||||
|         version = kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|         version = kernelize(version, device=device, mode=Mode.INFERENCE) | ||||
|         assert version() == "0.1.1" | ||||
|  | ||||
|     with use_kernel_mapping( | ||||
|         { | ||||
|             "Version": { | ||||
|                 Device(type="cuda"): LayerRepository( | ||||
|                 Device(type=device): LayerRepository( | ||||
|                     repo_id="kernels-test/versions", | ||||
|                     layer_name="Version", | ||||
|                     version=">0.2.0", | ||||
| @ -1024,13 +1114,13 @@ def test_layer_versions(): | ||||
|         } | ||||
|     ): | ||||
|         with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): | ||||
|             kernelize(version, device="cuda", mode=Mode.INFERENCE) | ||||
|             kernelize(version, device=device, mode=Mode.INFERENCE) | ||||
|  | ||||
|     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): | ||||
|         use_kernel_mapping( | ||||
|             { | ||||
|                 "Version": { | ||||
|                     Device(type="cuda"): LayerRepository( | ||||
|                     Device(type=device): LayerRepository( | ||||
|                         repo_id="kernels-test/versions", | ||||
|                         layer_name="Version", | ||||
|                         revision="v0.1.0", | ||||
|  | ||||
		Reference in New Issue
	
	Block a user
	