Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-31 03:14:29 +08:00)

Compare commits: sync-with-...v0.3.0 (38 commits)
SHA1: 1c7c87c960, df45cf2795, cf0413efe5, 851c13f666, b6a393612f, 18ecd0ce69, b4ef1d60e5, a40756f306, 3671158f47, 2ddd473cf7, 497dffb89e, f036fd09cb, 3e4c83c798, 4116d6019e, bd166b348a, 386c2a104e, c7516b9e50, a8dcd1f6bc, af7fdf9202, 9426e7e290, df2c165d61, d89239464a, 3212affd9e, 7ff40a859c, cf64113c8b, ba4f88f5aa, d61971ad46, d7f3831992, 03875be8a0, e41ef2358e, aca3ce7dfb, 3bae6fca7d, cffbafa61f, 29b27a58cf, bee46be22b, e05ba73534, 544354cb97, 105704b910
							
								
								
									
.github/workflows/docker-build-matrix.yml (vendored, 119 lines deleted)
							| @ -1,119 +0,0 @@ | |||||||
| name: Docker Build Matrix |  | ||||||
|  |  | ||||||
| on: |  | ||||||
|   push: |  | ||||||
|     branches: [main] |  | ||||||
|   pull_request: |  | ||||||
|     branches: [main] |  | ||||||
|     types: [opened, synchronize, reopened] # trigger on PRs |  | ||||||
|   workflow_dispatch: |  | ||||||
|  |  | ||||||
| concurrency: |  | ||||||
|   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} |  | ||||||
|   cancel-in-progress: true |  | ||||||
|  |  | ||||||
| jobs: |  | ||||||
|   build: |  | ||||||
|     name: Build Docker Image |  | ||||||
|     runs-on: |  | ||||||
|       group: aws-g6-24xlarge |  | ||||||
|     permissions: |  | ||||||
|       contents: read |  | ||||||
|       packages: write |  | ||||||
|     strategy: |  | ||||||
|       max-parallel: 4 |  | ||||||
|       matrix: |  | ||||||
|         # python: ["3.10", "3.11", "3.12"] |  | ||||||
|         # ubuntu: ["18.04", "20.04", "22.04"] |  | ||||||
|         # cuda: ["11.8.0", "12.1.0", "12.2.0", "12.4.0", "12.6.0"] |  | ||||||
|         # torch: ["2.4.0", "2.5.0"] |  | ||||||
|         include: |  | ||||||
|           - ubuntu: "18.04" |  | ||||||
|             cuda: "11.8.0" |  | ||||||
|             torch: "2.4.0" |  | ||||||
|             python: "3.10" |  | ||||||
|           - ubuntu: "22.04" |  | ||||||
|             cuda: "12.4.0" |  | ||||||
|             torch: "2.5.1" |  | ||||||
|             python: "3.12" |  | ||||||
|  |  | ||||||
|     steps: |  | ||||||
|       - name: Checkout code |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|  |  | ||||||
|       - name: Set up Docker Buildx |  | ||||||
|         uses: docker/setup-buildx-action@v3 |  | ||||||
|  |  | ||||||
|       - name: Generate Docker metadata |  | ||||||
|         id: meta |  | ||||||
|         uses: docker/metadata-action@v5 |  | ||||||
|         with: |  | ||||||
|           images: ghcr.io/${{ github.repository }}/hf_kernels |  | ||||||
|           tags: | |  | ||||||
|             type=raw,value=${{ matrix.cuda }}-${{ matrix.torch }}-python${{ matrix.python }}-ubuntu${{ matrix.ubuntu }} |  | ||||||
|             type=sha,prefix=${{ matrix.cuda }}-${{ matrix.torch }}-python${{ matrix.python }}-ubuntu${{ matrix.ubuntu }}- |  | ||||||
|  |  | ||||||
|       - name: Build Docker image |  | ||||||
|         uses: docker/build-push-action@v5 |  | ||||||
|         with: |  | ||||||
|           context: . |  | ||||||
|           file: docker/Dockerfile |  | ||||||
|           platforms: linux/amd64 |  | ||||||
|           build-args: | |  | ||||||
|             PYTHON_VERSION=${{ matrix.python }} |  | ||||||
|             UBUNTU_VERSION=${{ matrix.ubuntu }} |  | ||||||
|             CUDA_VERSION=${{ matrix.cuda }} |  | ||||||
|             TORCH_VERSION=${{ matrix.torch }} |  | ||||||
|           push: false |  | ||||||
|           load: true |  | ||||||
|           tags: ${{ steps.meta.outputs.tags }} |  | ||||||
|           labels: ${{ steps.meta.outputs.labels }} |  | ||||||
|           cache-from: type=gha,name=hf-kernels-cache-${{ matrix.ubuntu }}-${{ matrix.python }}-${{ matrix.cuda }}-${{ matrix.torch }} |  | ||||||
|           cache-to: type=gha,name=hf-kernels-cache-${{ matrix.ubuntu }}-${{ matrix.python }}-${{ matrix.cuda }}-${{ matrix.torch }} |  | ||||||
|  |  | ||||||
|       - name: Save Docker image |  | ||||||
|         run: | |  | ||||||
|           IMAGE_TAG="${{ steps.meta.outputs.tags }}" |  | ||||||
|           # Get the first tag if multiple tags are present |  | ||||||
|           FIRST_TAG=$(echo "$IMAGE_TAG" | head -n 1) |  | ||||||
|           docker save -o /tmp/docker-image-${{ matrix.cuda }}-${{ matrix.torch }}-python${{ matrix.python }}-ubuntu${{ matrix.ubuntu }}.tar "$FIRST_TAG" |  | ||||||
|  |  | ||||||
|       # Note: recommended to upload images via artifacts to share across jobs |  | ||||||
|       # https://docs.docker.com/build/ci/github-actions/share-image-jobs/ |  | ||||||
|       - name: Upload Docker image artifact |  | ||||||
|         uses: actions/upload-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           name: docker-image-${{ matrix.cuda }}-${{ matrix.torch }}-python${{ matrix.python }}-ubuntu${{ matrix.ubuntu }} |  | ||||||
|           path: /tmp/docker-image-${{ matrix.cuda }}-${{ matrix.torch }}-python${{ matrix.python }}-ubuntu${{ matrix.ubuntu }}.tar |  | ||||||
|           retention-days: 1 |  | ||||||
|  |  | ||||||
|   test: |  | ||||||
|     needs: build |  | ||||||
|     name: Test Docker Images |  | ||||||
|     runs-on: |  | ||||||
|       group: aws-g6-12xlarge-plus |  | ||||||
|     steps: |  | ||||||
|       - name: Checkout code |  | ||||||
|         uses: actions/checkout@v4 |  | ||||||
|  |  | ||||||
|       - name: Download all Docker images |  | ||||||
|         uses: actions/download-artifact@v4 |  | ||||||
|         with: |  | ||||||
|           pattern: docker-image-* |  | ||||||
|           path: /tmp |  | ||||||
|           merge-multiple: true |  | ||||||
|  |  | ||||||
|       - name: Load and test Docker images |  | ||||||
|         run: | |  | ||||||
|           for image_tar in /tmp/docker-image-*.tar; do |  | ||||||
|               echo "Processing image $image_tar" |  | ||||||
|               # Extract the version tag from the filename without the 'docker-image-' prefix |  | ||||||
|               docker_tag=$(basename $image_tar .tar | sed 's/^docker-image-//') |  | ||||||
|               echo "Loading image with tag $docker_tag" |  | ||||||
|               docker load -i $image_tar |  | ||||||
|               echo "Loaded image $docker_tag" |  | ||||||
|               docker run --gpus all \ |  | ||||||
|                   -v /home/runner/_work/hf-kernels/hf-kernels/tests:/workspace/tests \ |  | ||||||
|                   ghcr.io/huggingface/hf-kernels/hf_kernels:$docker_tag |  | ||||||
|               echo "Tested image $docker_tag" |  | ||||||
|           done |  | ||||||
							
								
								
									
.github/workflows/lint.yml (vendored, new file, 10 lines)
							| @ -0,0 +1,10 @@ | |||||||
|  | name: Lints | ||||||
|  | on: [push, pull_request] | ||||||
|  | jobs: | ||||||
|  |   lint: | ||||||
|  |     name: Run lints | ||||||
|  |     runs-on: ubuntu-latest | ||||||
|  |     steps: | ||||||
|  |       - uses: actions/checkout@v4 | ||||||
|  |       - name: Run ruff | ||||||
|  |         uses: astral-sh/ruff-action@v3 | ||||||
							
								
								
									
.github/workflows/test.yml (vendored, new file, 54 lines)
							| @ -0,0 +1,54 @@ | |||||||
|  | name: Test kernels | ||||||
|  |  | ||||||
|  | on: | ||||||
|  |   push: | ||||||
|  |     branches: [main] | ||||||
|  |   pull_request: | ||||||
|  |     branches: [main] | ||||||
|  |     types: [opened, synchronize, reopened] # trigger on PRs | ||||||
|  |   workflow_dispatch: | ||||||
|  |  | ||||||
|  | concurrency: | ||||||
|  |   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} | ||||||
|  |   cancel-in-progress: true | ||||||
|  |  | ||||||
|  | jobs: | ||||||
|  |   build: | ||||||
|  |     name: Run tests | ||||||
|  |     runs-on: | ||||||
|  |       group: aws-g6-24xlarge | ||||||
|  |     permissions: | ||||||
|  |       contents: read | ||||||
|  |       packages: write | ||||||
|  |     strategy: | ||||||
|  |       max-parallel: 4 | ||||||
|  |       matrix: | ||||||
|  |         python-version: ["3.10", "3.12"] | ||||||
|  |         torch-version: ["2.5.1", "2.6.0"] | ||||||
|  |  | ||||||
|  |     env: | ||||||
|  |       UV_PYTHON_PREFERENCE: only-managed | ||||||
|  |  | ||||||
|  |     steps: | ||||||
|  |       - name: Checkout code | ||||||
|  |         uses: actions/checkout@v4 | ||||||
|  |  | ||||||
|  |       - name: Install uv and set the python version | ||||||
|  |         uses: astral-sh/setup-uv@v5 | ||||||
|  |         with: | ||||||
|  |           python-version: ${{ matrix.python-version }} | ||||||
|  |  | ||||||
|  |       - name: Lock Torch version | ||||||
|  |         run: uv lock --upgrade-package "torch==${{ matrix.torch-version }}" | ||||||
|  |  | ||||||
|  |       - name: Install the project | ||||||
|  |         run: uv sync --all-extras --dev | ||||||
|  |  | ||||||
|  |       - name: Install setuptools for Triton-based test | ||||||
|  |         run: uv pip install setuptools | ||||||
|  |  | ||||||
|  |       - name: Check typing | ||||||
|  |         run: uv run mypy src/kernels | ||||||
|  |  | ||||||
|  |       - name: Run tests | ||||||
|  |         run: uv run pytest tests | ||||||
							
								
								
									
README.md (40 lines changed)
							| @ -1,11 +1,31 @@ | |||||||
| # hf-kernels | # kernels | ||||||
|  |  | ||||||
| Make sure you have `torch==2.5.1+cu124` installed. | The Kernel Hub allows Python libraries and applications to load compute | ||||||
|  | kernels directly from the [Hub](https://hf.co/). To support this kind | ||||||
|  | of dynamic loading, Hub kernels differ from traditional Python kernel | ||||||
|  | packages in that they are made to be: | ||||||
|  |  | ||||||
|  | - Portable: a kernel can be loaded from paths outside `PYTHONPATH`. | ||||||
|  | - Unique: multiple versions of the same kernel can be loaded in the | ||||||
|  |   same Python process. | ||||||
|  | - Compatible: kernels must support all recent versions of Python and | ||||||
|  |   the different PyTorch build configurations (various CUDA versions | ||||||
|  |   and C++ ABIs). Furthermore, older C library versions must be supported. | ||||||
|  |  | ||||||
|  | ## 🚀 Quick Start | ||||||
|  |  | ||||||
|  | Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA): | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | pip install kernels | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub: | ||||||
|  |  | ||||||
| ```python | ```python | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
| from hf_kernels import get_kernel | from kernels import get_kernel | ||||||
|  |  | ||||||
| # Download optimized kernels from the Hugging Face hub | # Download optimized kernels from the Hugging Face hub | ||||||
| activation = get_kernel("kernels-community/activation") | activation = get_kernel("kernels-community/activation") | ||||||
| @ -20,11 +40,13 @@ activation.gelu_fast(y, x) | |||||||
| print(y) | print(y) | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ## Docker Reference | You can [search for kernels](https://huggingface.co/models?other=kernel) on | ||||||
|  | the Hub. | ||||||
|  |  | ||||||
| build and run the reference [example/basic.py](example/basic.py) in a Docker container with the following commands: | ## 📚 Documentation | ||||||
|  |  | ||||||
| ```bash | - [Using layers](docs/layers.md) | ||||||
| docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | - [Locking kernel versions](docs/locking.md) | ||||||
| docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | - [Using kernels in a Docker container](docs/docker.md) | ||||||
| ``` | - [Kernel requirements](docs/kernel-requirements.md) | ||||||
|  | - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||||
|  | |||||||
| @ -1,81 +0,0 @@ | |||||||
| # syntax=docker/dockerfile:1.4 |  | ||||||
| ARG PYTHON_VERSION=3.10 |  | ||||||
| ARG CUDA_VERSION=12.4.0 |  | ||||||
| ARG UBUNTU_VERSION=20.04 |  | ||||||
| ARG TORCH_VERSION=2.5.0 |  | ||||||
|  |  | ||||||
| FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} as base |  | ||||||
|  |  | ||||||
| # Set environment variables |  | ||||||
| ENV DEBIAN_FRONTEND=noninteractive \ |  | ||||||
|     PYTHONUNBUFFERED=1 \ |  | ||||||
|     PATH="/root/.local/bin:/root/.cargo/bin:${PATH}" \ |  | ||||||
|     NVIDIA_VISIBLE_DEVICES=all \ |  | ||||||
|     NVIDIA_DRIVER_CAPABILITIES=compute,utility |  | ||||||
|  |  | ||||||
| # Install system dependencies |  | ||||||
| RUN apt-get update && apt-get install -y --no-install-recommends \ |  | ||||||
|     git \ |  | ||||||
|     git-lfs \ |  | ||||||
|     curl \ |  | ||||||
|     python3 \ |  | ||||||
|     python3-pip \ |  | ||||||
|     && rm -rf /var/lib/apt/lists/* \ |  | ||||||
|     && git lfs install |  | ||||||
|  |  | ||||||
| # Install uv package manager |  | ||||||
| RUN curl -LsSf https://astral.sh/uv/install.sh | sh |  | ||||||
|  |  | ||||||
| # Set working directory |  | ||||||
| WORKDIR /app |  | ||||||
|  |  | ||||||
| # Need to re-declare ARG after FROM for use in RUN |  | ||||||
| ARG CUDA_VERSION |  | ||||||
| ARG TORCH_VERSION |  | ||||||
| ARG PYTHON_VERSION |  | ||||||
|  |  | ||||||
| RUN echo "Building with CUDA_VERSION=${CUDA_VERSION}, TORCH_VERSION=${TORCH_VERSION}, PYTHON_VERSION=${PYTHON_VERSION}" |  | ||||||
|  |  | ||||||
| # Initialize uv and create virtual env |  | ||||||
| RUN uv init --app kernel-test --python "${PYTHON_VERSION}" |  | ||||||
|  |  | ||||||
| # Move into the app |  | ||||||
| WORKDIR /app/kernel-test |  | ||||||
|  |  | ||||||
| # Install PyTorch with the appropriate CUDA version |  | ||||||
|  |  | ||||||
| # NOTE: `markupsafe` must be installed first to avoid a conflict with the torch package.  |  | ||||||
| # See: https://github.com/astral-sh/uv/issues/9647 |  | ||||||
|  |  | ||||||
| RUN CUDA_MAJOR_MINOR=$(echo ${CUDA_VERSION} | cut -d'.' -f1,2) && \ |  | ||||||
|     case ${CUDA_MAJOR_MINOR} in \ |  | ||||||
|     "11.8") CUDA_TAG="cu118" ;; \ |  | ||||||
|     "12.1") CUDA_TAG="cu121" ;; \ |  | ||||||
|     "12.2") CUDA_TAG="cu122" ;; \ |  | ||||||
|     "12.4") CUDA_TAG="cu124" ;; \ |  | ||||||
|     *) CUDA_TAG="" ;; \ |  | ||||||
|     esac && \ |  | ||||||
|     if [ -n "${CUDA_TAG}" ]; then \ |  | ||||||
|     echo "Installing PyTorch ${TORCH_VERSION} with CUDA ${CUDA_TAG}" && \ |  | ||||||
|     uv add markupsafe --default-index "https://pypi.org/simple" && \ |  | ||||||
|     uv add "torch==${TORCH_VERSION}" --index-url "https://download.pytorch.org/whl/${CUDA_TAG}"; \ |  | ||||||
|     else \ |  | ||||||
|     echo "Installing PyTorch ${TORCH_VERSION} without CUDA-specific index" && \ |  | ||||||
|     uv add "torch==${TORCH_VERSION}"; \ |  | ||||||
|     fi |  | ||||||
|  |  | ||||||
| # add pytest for runtime tests |  | ||||||
| RUN uv add pytest pytest-benchmark huggingface_hub |  | ||||||
|  |  | ||||||
| # Copy application files |  | ||||||
| COPY src ./hf_kernels/src |  | ||||||
| COPY pyproject.toml ./hf_kernels/pyproject.toml |  | ||||||
| COPY README.md ./hf_kernels/README.md |  | ||||||
| COPY examples ./examples |  | ||||||
| COPY tests ./tests |  | ||||||
|  |  | ||||||
| # Install the kernel library |  | ||||||
| RUN uv pip install hf_kernels |  | ||||||
|  |  | ||||||
| # Run tests and benchmarks |  | ||||||
| CMD [".venv/bin/pytest", "tests", "-v"]  |  | ||||||
| @ -31,13 +31,13 @@ WORKDIR /app/kernel-test | |||||||
| # install python dependencies | # install python dependencies | ||||||
| RUN uv add torch==2.5.0 numpy | RUN uv add torch==2.5.0 numpy | ||||||
|  |  | ||||||
| # copy hf-kernels lib | # copy kernels lib | ||||||
| COPY src ./hf-kernels/src | COPY src ./kernels/src | ||||||
| COPY pyproject.toml ./hf-kernels/pyproject.toml | COPY pyproject.toml ./kernels/pyproject.toml | ||||||
| COPY README.md ./hf-kernels/README.md | COPY README.md ./kernels/README.md | ||||||
|  |  | ||||||
| # install library | # install library | ||||||
| RUN uv pip install -e hf-kernels | RUN uv pip install -e kernels | ||||||
|  |  | ||||||
| # copy examples | # copy examples | ||||||
| COPY examples ./examples | COPY examples ./examples | ||||||
| @ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility | |||||||
|  |  | ||||||
| # command to run the script | # command to run the script | ||||||
| CMD ["uv", "run", "examples/basic.py"] | CMD ["uv", "run", "examples/basic.py"] | ||||||
| # CMD ["ls", "hf-kernels"] | # CMD ["ls", "kernels"] | ||||||
|  | |||||||
							
								
								
									
docs/docker.md (new file, 8 lines)
							| @ -0,0 +1,8 @@ | |||||||
|  | # Using kernels in a Docker container | ||||||
|  |  | ||||||
|  | Build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . | ||||||
|  | docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference | ||||||
|  | ``` | ||||||
							
								
								
									
docs/kernel-requirements.md (new file, 177 lines)
							| @ -0,0 +1,177 @@ | |||||||
|  | # Kernel requirements | ||||||
|  |  | ||||||
|  | Kernels on the Hub must fulfill the requirements outlined on this page. | ||||||
|  | You can use [kernel-builder](https://github.com/huggingface/kernel-builder/) | ||||||
|  | to build conforming kernels. | ||||||
|  |  | ||||||
|  | ## Directory layout | ||||||
|  |  | ||||||
|  | A kernel repository on the Hub must contain a `build` directory. This | ||||||
|  | directory contains build variants of a kernel in the form of directories | ||||||
|  | following the template | ||||||
|  | `<framework><version>-cxx<abiver>-cu<cudaver>-<arch>-<os>`. | ||||||
|  | For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently | ||||||
|  | recommended build variants are: | ||||||
|  |  | ||||||
|  | - `torch25-cxx11-cu118-x86_64-linux` | ||||||
|  | - `torch25-cxx11-cu121-x86_64-linux` | ||||||
|  | - `torch25-cxx11-cu124-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu118-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu121-x86_64-linux` | ||||||
|  | - `torch25-cxx98-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu118-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx11-cu126-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu118-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu124-x86_64-linux` | ||||||
|  | - `torch26-cxx98-cu126-x86_64-linux` | ||||||
|  |  | ||||||
|  | This list will be updated as new PyTorch versions are released. Kernels | ||||||
|  | that are in pure Python (e.g. Triton kernels) only need to provide a | ||||||
|  | single build variant: | ||||||
|  |  | ||||||
|  | - `torch-universal` | ||||||
|  |  | ||||||
|  | Each variant directory should contain a single directory with the same name | ||||||
|  | as the repository (replacing `-` by `_`). For instance, kernels in the | ||||||
|  | `kernels-community/activation` repository have a directory like | ||||||
|  | `build/<variant>/activation`. This directory | ||||||
|  | must be a Python package with an `__init__.py` file. | ||||||
|  |  | ||||||
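For illustration, the variant name for the running environment can be derived from Torch itself. The sketch below mirrors the `build_variant` helper in the removed `hf_kernels` utilities later in this diff; it assumes a CUDA build of Torch (`torch.version.cuda` is `None` on CPU-only builds):

```python
import platform

import torch
from packaging.version import parse


def build_variant() -> str:
    # Fill the <framework><version>-cxx<abiver>-cu<cudaver>-<arch>-<os> template.
    torch_version = parse(torch.__version__)
    cuda_version = parse(torch.version.cuda)  # None on CPU-only builds
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
    cpu = platform.machine()
    os = platform.system().lower()
    return (
        f"torch{torch_version.major}{torch_version.minor}-{cxxabi}"
        f"-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
    )


# e.g. "torch26-cxx11-cu124-x86_64-linux"
```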
|  | ## Native Python module | ||||||
|  |  | ||||||
|  | Kernels will typically contain a native Python module with precompiled | ||||||
|  | compute kernels and bindings. This module must fulfill the following | ||||||
|  | requirements: | ||||||
|  |  | ||||||
|  | - Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface) | ||||||
|  |   for compatibility with Python 3.9 and later. | ||||||
|  | - Compatible with glibc 2.27 or later. This means that no symbols | ||||||
|  |   from later versions may be used. To achieve this, the module should | ||||||
|  |   be built against this glibc version. **Warning:** libgcc must also be | ||||||
|  |   built against glibc 2.27 to avoid leaking symbols. | ||||||
|  | - No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols | ||||||
|  |   must be static. | ||||||
|  | - No dynamic library dependencies outside Torch or CUDA libraries | ||||||
|  |   installed as dependencies of Torch. | ||||||
|  |  | ||||||
|  | (These requirements will be updated as new PyTorch versions are released.) | ||||||
|  |  | ||||||
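As a quick sanity check for the stable-ABI requirement above: extensions built against the limited API are tagged with an `.abi3.so` suffix on Linux, so untagged native modules in a build variant can be flagged. A minimal sketch (the helper name is illustrative):

```python
from pathlib import Path


def find_non_abi3_modules(variant_dir: str) -> list[str]:
    # Native modules built for the CPython limited API carry the
    # ".abi3.so" suffix; anything else is suspect.
    return [
        str(path)
        for path in Path(variant_dir).rglob("*.so")
        if not path.name.endswith(".abi3.so")
    ]


# e.g. find_non_abi3_modules("build/torch26-cxx11-cu124-x86_64-linux")
```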
|  | ## Torch extension | ||||||
|  |  | ||||||
|  | Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial) | ||||||
|  | in `torch.ops.<namespace>`. Since we allow loading of multiple versions of | ||||||
|  | a module in the same Python process, `namespace` must be unique for each | ||||||
|  | version of a kernel. Failing to do so will create clashes when different | ||||||
|  | versions of the same kernel are loaded. Two suggested ways of doing this | ||||||
|  | are: | ||||||
|  |  | ||||||
|  | - Appending a truncated SHA-1 hash of the git commit that the kernel was | ||||||
|  |   built from to the name of the extension. | ||||||
|  | - Appending random material to the name of the extension. | ||||||
|  |  | ||||||
|  | **Note:** we recommend against appending a version number or git tag. | ||||||
|  | Version numbers are typically not bumped on each commit, so users | ||||||
|  | might use two different commits that happen to have the same version | ||||||
|  | number. Git tags are not stable, so they do not provide a good way | ||||||
|  | of guaranteeing uniqueness of the namespace. | ||||||
|  |  | ||||||
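To make the uniqueness requirement concrete, here is a minimal Python-level sketch; real kernels usually register their ops from C++ via `TORCH_LIBRARY`, and the `build_id`, namespace, and op below are purely illustrative:

```python
import torch
import torch.nn.functional as F

# Illustrative build identifier: in practice, a truncated SHA-1 of the
# build commit or random material baked in at build time.
build_id = "bc6d88f"

# Define the op in a namespace that is unique to this build.
lib = torch.library.Library(f"activation_{build_id}", "DEF")
lib.define("silu_and_mul(Tensor x) -> Tensor")


def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


lib.impl("silu_and_mul", _silu_and_mul, "CompositeExplicitAutograd")

# The op is now reachable as torch.ops.activation_bc6d88f.silu_and_mul;
# a build with a different build_id registers under a distinct namespace.
```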
|  | ## Layers | ||||||
|  |  | ||||||
|  | A kernel can provide layers in addition to kernel functions. A layer from | ||||||
|  | the Hub can replace the `forward` method of an existing layer for a certain | ||||||
|  | device type. This makes it possible to provide more performant kernels for | ||||||
|  | existing layers. See the [layers documentation](layers.md) for more information | ||||||
|  | on how to use layers. | ||||||
|  |  | ||||||
|  | ### Writing layers | ||||||
|  |  | ||||||
|  | To make the extension of layers safe, the layers must fulfill the following | ||||||
|  | requirements: | ||||||
|  |  | ||||||
|  | - The layers are subclasses of `torch.nn.Module`. | ||||||
|  | - The layers are pure, meaning that they do not have their own state. This | ||||||
|  |   means that: | ||||||
|  |   - The layer must not define its own constructor. | ||||||
|  |   - The layer must not use class variables. | ||||||
|  | - No methods other than `forward` may be defined. | ||||||
|  | - The `forward` method has a signature that is compatible with the | ||||||
|  |   `forward` method that it is extending. | ||||||
|  |  | ||||||
|  | This is an example of a pure layer: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | class SiluAndMul(nn.Module): | ||||||
|  |     def forward(self, x: torch.Tensor): | ||||||
|  |         d = x.shape[-1] // 2 | ||||||
|  |         output_shape = x.shape[:-1] + (d,) | ||||||
|  |         out = torch.empty(output_shape, dtype=x.dtype, device=x.device) | ||||||
|  |         ops.silu_and_mul(out, x) | ||||||
|  |         return out | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | For some layers, the `forward` method has to use state from the adopting class. | ||||||
|  | In these cases, we recommend using type annotations to indicate which member | ||||||
|  | variables are expected. For instance: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | class LlamaRMSNorm(nn.Module): | ||||||
|  |     weight: torch.Tensor | ||||||
|  |     variance_epsilon: float | ||||||
|  |  | ||||||
|  |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | ||||||
|  |         return rms_norm_fn( | ||||||
|  |             hidden_states, | ||||||
|  |             self.weight, | ||||||
|  |             bias=None, | ||||||
|  |             residual=None, | ||||||
|  |             eps=self.variance_epsilon, | ||||||
|  |             dropout_p=0.0, | ||||||
|  |             prenorm=False, | ||||||
|  |             residual_in_fp32=False, | ||||||
|  |         ) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This layer expects the adopting layer to have `weight` and `variance_epsilon` | ||||||
|  | member variables and uses them in the `forward` method. | ||||||
|  |  | ||||||
|  | ### Exporting layers | ||||||
|  |  | ||||||
|  | To accommodate portable loading, `layers` must be defined in the main | ||||||
|  | `__init__.py` file. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from . import layers | ||||||
|  |  | ||||||
|  | __all__ = [ | ||||||
|  |   # ... | ||||||
|  |   "layers" | ||||||
|  |   # ... | ||||||
|  | ] | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Python requirements | ||||||
|  |  | ||||||
|  | - Python code must be compatible with Python 3.9 and later. | ||||||
|  | - All Python code imports from the kernel itself must be relative. So, | ||||||
|  |   for instance if in the example kernel `example`, | ||||||
|  |   `module_b` needs a function from `module_a`, import as: | ||||||
|  |  | ||||||
|  |   ```python | ||||||
|  |   from .module_a import foo | ||||||
|  |   ``` | ||||||
|  |  | ||||||
|  |   **Never use:** | ||||||
|  |  | ||||||
|  |   ```python | ||||||
|  |   # DO NOT DO THIS! | ||||||
|  |  | ||||||
|  |   from example.module_a import foo | ||||||
|  |   ``` | ||||||
|  |  | ||||||
|  |   The latter would import from the module `example` that is in Python's | ||||||
|  |   global module dict. However, since we allow loading multiple versions | ||||||
|  |   of a module, we uniquely name the module. | ||||||
|  |  | ||||||
|  | - Only modules from the Python standard library, Torch, or the kernel itself | ||||||
|  |   can be imported. | ||||||
							
								
								
									
docs/layers.md (new file, 79 lines)
							| @ -0,0 +1,79 @@ | |||||||
|  | # Layers | ||||||
|  |  | ||||||
|  | A kernel can provide layers in addition to kernel functions. A layer from | ||||||
|  | the Hub can replace the `forward` method of an existing layer for a certain | ||||||
|  | device type. This makes it possible to provide more performant kernels for | ||||||
|  | existing layers. | ||||||
|  |  | ||||||
|  | See [Kernel requirements](kernel-requirements.md) for more information on the | ||||||
|  | requirements of Hub layers. | ||||||
|  |  | ||||||
|  | ## Making a layer extensible with kernels from the hub | ||||||
|  |  | ||||||
|  | ### Using a decorator | ||||||
|  |  | ||||||
|  | A layer can be made extensible with the `use_kernel_forward_from_hub` | ||||||
|  | decorator. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | @use_kernel_forward_from_hub("SiluAndMul") | ||||||
|  | class SiluAndMul(nn.Module): | ||||||
|  |     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         d = input.shape[-1] // 2 | ||||||
|  |         return F.silu(input[..., :d]) * input[..., d:] | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The decorator changes the layer, so that other implementations of the `forward` | ||||||
|  | method can be registered using the name `SiluAndMul`. | ||||||
|  |  | ||||||
|  | ### External layers | ||||||
|  |  | ||||||
|  | An existing layer that does not (yet) have the `use_kernel_forward_from_hub` | ||||||
|  | decorator can be made extensible by monkeypatching it using the `replace_kernel_forward_from_hub` function. | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from somelibrary import SiluAndMul | ||||||
|  |  | ||||||
|  | replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | The `register_kernel_mapping` call maps the name `SiluAndMul` to actual | ||||||
|  | hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer) | ||||||
|  | section for more information. | ||||||
|  |  | ||||||
|  | **Warning:** we strongly recommend using layers with a decorator, since | ||||||
|  | it signifies that the maintainer intends to keep the `forward` signature | ||||||
|  | compatible with layers from the hub. | ||||||
|  |  | ||||||
|  | ## Registering a hub kernel for a layer | ||||||
|  |  | ||||||
|  | Once a layer is made extensible, users can register hub kernels for it | ||||||
|  | by name using the `register_kernel_mapping` function. For example: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         "cuda": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             revision="layers", | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This will register the kernel mapping in the current context, which is | ||||||
|  | normally global. It is recommended to scope the mapping to where it is | ||||||
|  | used with the `use_kernel_mapping` context manager: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | with use_kernel_mapping(kernel_layer_mapping): | ||||||
|  |     # Use the layer for which the mapping is applied. | ||||||
|  |     ... | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | This ensures that the mapping is no longer active outside the | ||||||
|  | `with` scope. | ||||||
							
								
								
									
docs/locking.md (new file, 44 lines)
							| @ -0,0 +1,44 @@ | |||||||
|  | # Locking kernel versions | ||||||
|  |  | ||||||
|  | Projects that use `setuptools` can lock the kernel versions that should be | ||||||
|  | used. First specify the accepted versions in `pyproject.toml` and make | ||||||
|  | sure that `kernels` is a build dependency: | ||||||
|  |  | ||||||
|  | ```toml | ||||||
|  | [build-system] | ||||||
|  | requires = ["kernels", "setuptools"] | ||||||
|  | build-backend = "setuptools.build_meta" | ||||||
|  |  | ||||||
|  | [tool.kernels.dependencies] | ||||||
|  | "kernels-community/activation" = ">=0.0.1" | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with | ||||||
|  | the locked revisions. The locked revision will be used when loading a kernel with | ||||||
|  | `get_locked_kernel`: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from kernels import get_locked_kernel | ||||||
|  |  | ||||||
|  | activation = get_locked_kernel("kernels-community/activation") | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Note:** the lock file is included in the package metadata, so it will only be visible | ||||||
|  | to `kernels` after doing an (editable or regular) installation of your project. | ||||||
|  |  | ||||||
|  | ## Pre-downloading locked kernels | ||||||
|  |  | ||||||
|  | Locked kernels can be pre-downloaded by running `kernels download .` in your | ||||||
|  | project directory. This will download the kernels to your local Hugging Face | ||||||
|  | Hub cache. | ||||||
|  |  | ||||||
|  | The pre-downloaded kernels are used by the `get_locked_kernel` function. | ||||||
|  | `get_locked_kernel` will download a kernel when it is not pre-downloaded. If you | ||||||
|  | want kernel loading to error when a kernel is not pre-downloaded, you can use | ||||||
|  | the `load_kernel` function instead: | ||||||
|  |  | ||||||
|  | ```python | ||||||
|  | from kernels import load_kernel | ||||||
|  |  | ||||||
|  | activation = load_kernel("kernels-community/activation") | ||||||
|  | ``` | ||||||
| @ -1,6 +1,6 @@ | |||||||
| import torch | import torch | ||||||
|  |  | ||||||
| from hf_kernels import get_kernel | from kernels import get_kernel | ||||||
|  |  | ||||||
| print("Starting examples/basic.py demo") | print("Starting examples/basic.py demo") | ||||||
|  |  | ||||||
|  | |||||||
							
								
								
									
flake.lock (generated, new file, 134 lines)
							| @ -0,0 +1,134 @@ | |||||||
|  | { | ||||||
|  |   "nodes": { | ||||||
|  |     "flake-compat": { | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1733328505, | ||||||
|  |         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", | ||||||
|  |         "owner": "edolstra", | ||||||
|  |         "repo": "flake-compat", | ||||||
|  |         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "edolstra", | ||||||
|  |         "repo": "flake-compat", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "flake-utils": { | ||||||
|  |       "inputs": { | ||||||
|  |         "systems": "systems" | ||||||
|  |       }, | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1731533236, | ||||||
|  |         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", | ||||||
|  |         "owner": "numtide", | ||||||
|  |         "repo": "flake-utils", | ||||||
|  |         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "numtide", | ||||||
|  |         "repo": "flake-utils", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "flake-utils_2": { | ||||||
|  |       "inputs": { | ||||||
|  |         "systems": "systems_2" | ||||||
|  |       }, | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1731533236, | ||||||
|  |         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", | ||||||
|  |         "owner": "numtide", | ||||||
|  |         "repo": "flake-utils", | ||||||
|  |         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "numtide", | ||||||
|  |         "repo": "flake-utils", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "nixpkgs": { | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1737453259, | ||||||
|  |         "narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=", | ||||||
|  |         "owner": "danieldk", | ||||||
|  |         "repo": "nixpkgs", | ||||||
|  |         "rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "danieldk", | ||||||
|  |         "ref": "outlines-v0.1.4-tgi", | ||||||
|  |         "repo": "nixpkgs", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "root": { | ||||||
|  |       "inputs": { | ||||||
|  |         "flake-utils": "flake-utils", | ||||||
|  |         "nixpkgs": [ | ||||||
|  |           "tgi-nix", | ||||||
|  |           "nixpkgs" | ||||||
|  |         ], | ||||||
|  |         "tgi-nix": "tgi-nix" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "systems": { | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1681028828, | ||||||
|  |         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", | ||||||
|  |         "owner": "nix-systems", | ||||||
|  |         "repo": "default", | ||||||
|  |         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "nix-systems", | ||||||
|  |         "repo": "default", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "systems_2": { | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1681028828, | ||||||
|  |         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", | ||||||
|  |         "owner": "nix-systems", | ||||||
|  |         "repo": "default", | ||||||
|  |         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "nix-systems", | ||||||
|  |         "repo": "default", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     }, | ||||||
|  |     "tgi-nix": { | ||||||
|  |       "inputs": { | ||||||
|  |         "flake-compat": "flake-compat", | ||||||
|  |         "flake-utils": "flake-utils_2", | ||||||
|  |         "nixpkgs": "nixpkgs" | ||||||
|  |       }, | ||||||
|  |       "locked": { | ||||||
|  |         "lastModified": 1741617161, | ||||||
|  |         "narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=", | ||||||
|  |         "owner": "huggingface", | ||||||
|  |         "repo": "text-generation-inference-nix", | ||||||
|  |         "rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925", | ||||||
|  |         "type": "github" | ||||||
|  |       }, | ||||||
|  |       "original": { | ||||||
|  |         "owner": "huggingface", | ||||||
|  |         "ref": "kernels-0.2.0", | ||||||
|  |         "repo": "text-generation-inference-nix", | ||||||
|  |         "type": "github" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   "root": "root", | ||||||
|  |   "version": 7 | ||||||
|  | } | ||||||
							
								
								
									
flake.nix (new file, 54 lines)
							| @ -0,0 +1,54 @@ | |||||||
|  | { | ||||||
|  |   inputs = { | ||||||
|  |     tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0"; | ||||||
|  |     nixpkgs.follows = "tgi-nix/nixpkgs"; | ||||||
|  |     flake-utils.url = "github:numtide/flake-utils"; | ||||||
|  |   }; | ||||||
|  |   outputs = | ||||||
|  |     { | ||||||
|  |       self, | ||||||
|  |       nixpkgs, | ||||||
|  |       flake-utils, | ||||||
|  |       tgi-nix, | ||||||
|  |     }: | ||||||
|  |     flake-utils.lib.eachDefaultSystem ( | ||||||
|  |       system: | ||||||
|  |       let | ||||||
|  |         pkgs = import nixpkgs { | ||||||
|  |           inherit system; | ||||||
|  |           inherit (tgi-nix.lib) config; | ||||||
|  |           overlays = [ | ||||||
|  |             tgi-nix.overlays.default | ||||||
|  |           ]; | ||||||
|  |         }; | ||||||
|  |       in | ||||||
|  |       { | ||||||
|  |         formatter = pkgs.nixfmt-rfc-style; | ||||||
|  |         devShells = with pkgs; rec { | ||||||
|  |           default = mkShell { | ||||||
|  |             buildInputs = | ||||||
|  |               [ | ||||||
|  |                 black | ||||||
|  |                 mypy | ||||||
|  |                 pyright | ||||||
|  |                 ruff | ||||||
|  |               ] | ||||||
|  |               ++ (with python3.pkgs; [ | ||||||
|  |                 huggingface-hub | ||||||
|  |                 pytest | ||||||
|  |                 pytest-benchmark | ||||||
|  |                 torch | ||||||
|  |                 venvShellHook | ||||||
|  |               ]); | ||||||
|  |  | ||||||
|  |             venvDir = "./.venv"; | ||||||
|  |  | ||||||
|  |             postVenvCreation = '' | ||||||
|  |               unset SOURCE_DATE_EPOCH | ||||||
|  |               ( python -m pip install --no-build-isolation --no-dependencies -e . ) | ||||||
|  |             ''; | ||||||
|  |           }; | ||||||
|  |         }; | ||||||
|  |       } | ||||||
|  |     ); | ||||||
|  | } | ||||||
| @ -1,22 +1,64 @@ | |||||||
| [project] | [project] | ||||||
| name = "hf-kernels" | name = "kernels" | ||||||
| version = "0.1.0" | version = "0.3.0" | ||||||
| description = "Download cuda kernels" | description = "Download compute kernels" | ||||||
| authors = [ | authors = [ | ||||||
|     {name = "OlivierDehaene", email = "olivier@huggingface.co"}, |   { name = "OlivierDehaene", email = "olivier@huggingface.co" }, | ||||||
|     {name = "Daniel de Kok", email = "daniel@huggingface.co"}, |   { name = "Daniel de Kok", email = "daniel@huggingface.co" }, | ||||||
|     {name = "David Holtz", email = "david@huggingface.co"}, |   { name = "David Holtz", email = "david@huggingface.co" }, | ||||||
|     {name = "Nicolas Patry", email = "nicolas@huggingface.co"} |   { name = "Nicolas Patry", email = "nicolas@huggingface.co" }, | ||||||
| ] | ] | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
|  | requires-python = ">= 3.9" | ||||||
| [dependencies] | dependencies = [ | ||||||
| python = "^3.9" |   "huggingface-hub>=0.26.3", | ||||||
| huggingface-hub = "^0.26.3" |   "packaging>=24.2", | ||||||
| packaging = "^24.2" |   "tomli>=2.0.1; python_version<'3.11'", | ||||||
| tomli = { version = "^2.0.1", python = "<3.11" } |   "torch>=2.5", | ||||||
|  | ] | ||||||
|  |  | ||||||
| [build-system] | [build-system] | ||||||
| requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"] | requires = ["setuptools"] | ||||||
| build-backend = "hf_kernels.build" | build-backend = "setuptools.build_meta" | ||||||
| backend-path = ["src"] |  | ||||||
|  | [dependency-groups] | ||||||
|  | dev = [ | ||||||
|  |   "mypy == 1.14.1", | ||||||
|  |   "pytest >=8", | ||||||
|  |   # Whatever version is compatible with pytest. | ||||||
|  |   "pytest-benchmark", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [project.scripts] | ||||||
|  | kernels = "kernels.cli:main" | ||||||
|  |  | ||||||
|  | [project.entry-points."egg_info.writers"] | ||||||
|  | "kernels.lock" = "kernels.lockfile:write_egg_lockfile" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | [tool.ruff] | ||||||
|  | exclude = [ | ||||||
|  |   ".eggs", | ||||||
|  |   ".git", | ||||||
|  |   ".git-rewrite", | ||||||
|  |   ".hg", | ||||||
|  |   ".mypy_cache", | ||||||
|  |   ".nox", | ||||||
|  |   ".pants.d", | ||||||
|  |   ".pytype", | ||||||
|  |   ".ruff_cache", | ||||||
|  |   ".svn", | ||||||
|  |   ".tox", | ||||||
|  |   ".venv", | ||||||
|  |   ".venv*", | ||||||
|  |   "__pypackages__", | ||||||
|  |   "_build", | ||||||
|  |   "build", | ||||||
|  |   "dist", | ||||||
|  |   "venv", | ||||||
|  | ] | ||||||
|  | line-length = 119 | ||||||
|  | # Ignored rules: | ||||||
|  | # "E501" -> line length violation | ||||||
|  | lint.ignore = ["E501"] | ||||||
|  | lint.select = ["E", "F", "I", "W"] | ||||||
|  | |||||||
| @ -1,3 +0,0 @@ | |||||||
| from hf_kernels.utils import get_kernel, load_kernel, install_kernel |  | ||||||
|  |  | ||||||
| __all__ = ["get_kernel", "load_kernel", "install_kernel"] |  | ||||||
| @ -1,149 +0,0 @@ | |||||||
| """ |  | ||||||
| Python shims for the PEP 517 and PEP 660 build backend. |  | ||||||
|  |  | ||||||
| Major imports in this module are required to be lazy: |  | ||||||
| ``` |  | ||||||
| $ hyperfine \ |  | ||||||
|      "/usr/bin/python3 -c \"print('hi')\"" \ |  | ||||||
|      "/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\"" |  | ||||||
| Base: Time (mean ± σ):      11.0 ms ±   1.7 ms    [User: 8.5 ms, System: 2.5 ms] |  | ||||||
| With import: Time (mean ± σ):      15.2 ms ±   2.0 ms    [User: 12.3 ms, System: 2.9 ms] |  | ||||||
| Base 1.38 ± 0.28 times faster than with import |  | ||||||
| ``` |  | ||||||
|  |  | ||||||
| The same thing goes for the typing module, so we use Python 3.10 type annotations that |  | ||||||
| don't require importing typing but then quote them so earlier Python versions ignore |  | ||||||
| them while IDEs and type checkers can see through the quotes. |  | ||||||
| """ |  | ||||||
|  |  | ||||||
| import sys |  | ||||||
|  |  | ||||||
| TYPE_CHECKING = False |  | ||||||
| if TYPE_CHECKING: |  | ||||||
|     from collections.abc import Mapping, Sequence  # noqa:I001 |  | ||||||
|     from typing import Any  # noqa:I001 |  | ||||||
|  |  | ||||||
| if sys.version_info >= (3, 11): |  | ||||||
|     import tomllib |  | ||||||
| else: |  | ||||||
|     import tomli as tomllib |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None: |  | ||||||
|     import sys |  | ||||||
|  |  | ||||||
|     if config_settings: |  | ||||||
|         print("Warning: Config settings are not supported", file=sys.stderr) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def call( |  | ||||||
|     args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None |  | ||||||
| ) -> str: |  | ||||||
|     """Invoke a uv subprocess and return the filename from stdout.""" |  | ||||||
|     import shutil |  | ||||||
|     import subprocess |  | ||||||
|     import sys |  | ||||||
|  |  | ||||||
|     warn_config_settings(config_settings) |  | ||||||
|     # Unlike `find_uv_bin`, this mechanism must work according to PEP 517 |  | ||||||
|     import os |  | ||||||
|  |  | ||||||
|     cwd = os.getcwd() |  | ||||||
|     filename = os.path.join(cwd, "pyproject.toml") |  | ||||||
|     with open(filename, "rb") as f: |  | ||||||
|         data = tomllib.load(f) |  | ||||||
|  |  | ||||||
|     for kernel, _ in ( |  | ||||||
|         data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items() |  | ||||||
|     ): |  | ||||||
|         from hf_kernels.utils import install_kernel |  | ||||||
|  |  | ||||||
|         install_kernel(kernel, revision="main") |  | ||||||
|     uv_bin = shutil.which("uv") |  | ||||||
|     if uv_bin is None: |  | ||||||
|         raise RuntimeError("uv was not properly installed") |  | ||||||
|     # Forward stderr, capture stdout for the filename |  | ||||||
|     result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE) |  | ||||||
|     if result.returncode != 0: |  | ||||||
|         sys.exit(result.returncode) |  | ||||||
|     # If there was extra stdout, forward it (there should not be extra stdout) |  | ||||||
|     stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True) |  | ||||||
|     sys.stdout.writelines(stdout[:-1]) |  | ||||||
|     # Fail explicitly instead of an irrelevant stacktrace |  | ||||||
|     if not stdout: |  | ||||||
|         print("uv subprocess did not return a filename on stdout", file=sys.stderr) |  | ||||||
|         sys.exit(1) |  | ||||||
|     return stdout[-1].strip() |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_sdist( |  | ||||||
|     sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None |  | ||||||
| ) -> str: |  | ||||||
|     """PEP 517 hook `build_sdist`.""" |  | ||||||
|     args = ["build-backend", "build-sdist", sdist_directory] |  | ||||||
|     return call(args, config_settings) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_wheel( |  | ||||||
|     wheel_directory: str, |  | ||||||
|     config_settings: "Mapping[Any, Any] | None" = None, |  | ||||||
|     metadata_directory: "str | None" = None, |  | ||||||
| ) -> str: |  | ||||||
|     """PEP 517 hook `build_wheel`.""" |  | ||||||
|     args = ["build-backend", "build-wheel", wheel_directory] |  | ||||||
|     if metadata_directory: |  | ||||||
|         args.extend(["--metadata-directory", metadata_directory]) |  | ||||||
|     return call(args, config_settings) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_requires_for_build_sdist( |  | ||||||
|     config_settings: "Mapping[Any, Any] | None" = None, |  | ||||||
| ) -> "Sequence[str]": |  | ||||||
|     """PEP 517 hook `get_requires_for_build_sdist`.""" |  | ||||||
|     warn_config_settings(config_settings) |  | ||||||
|     return [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_requires_for_build_wheel( |  | ||||||
|     config_settings: "Mapping[Any, Any] | None" = None, |  | ||||||
| ) -> "Sequence[str]": |  | ||||||
|     """PEP 517 hook `get_requires_for_build_wheel`.""" |  | ||||||
|     warn_config_settings(config_settings) |  | ||||||
|     return [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def prepare_metadata_for_build_wheel( |  | ||||||
|     metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None |  | ||||||
| ) -> str: |  | ||||||
|     """PEP 517 hook `prepare_metadata_for_build_wheel`.""" |  | ||||||
|     args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory] |  | ||||||
|     return call(args, config_settings) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_editable( |  | ||||||
|     wheel_directory: str, |  | ||||||
|     config_settings: "Mapping[Any, Any] | None" = None, |  | ||||||
|     metadata_directory: "str | None" = None, |  | ||||||
| ) -> str: |  | ||||||
|     """PEP 660 hook `build_editable`.""" |  | ||||||
|     args = ["build-backend", "build-editable", wheel_directory] |  | ||||||
|  |  | ||||||
|     if metadata_directory: |  | ||||||
|         args.extend(["--metadata-directory", metadata_directory]) |  | ||||||
|     return call(args, config_settings) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_requires_for_build_editable( |  | ||||||
|     config_settings: "Mapping[Any, Any] | None" = None, |  | ||||||
| ) -> "Sequence[str]": |  | ||||||
|     """PEP 660 hook `get_requires_for_build_editable`.""" |  | ||||||
|     warn_config_settings(config_settings) |  | ||||||
|     return [] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def prepare_metadata_for_build_editable( |  | ||||||
|     metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None |  | ||||||
| ) -> str: |  | ||||||
|     """PEP 660 hook `prepare_metadata_for_build_editable`.""" |  | ||||||
|     args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory] |  | ||||||
|     return call(args, config_settings) |  | ||||||
| @ -1,61 +0,0 @@ | |||||||
| import importlib.util |  | ||||||
| import platform |  | ||||||
| import sys |  | ||||||
| import os |  | ||||||
|  |  | ||||||
| import torch |  | ||||||
| from huggingface_hub import hf_hub_download, snapshot_download |  | ||||||
| from packaging.version import parse |  | ||||||
|  |  | ||||||
| if sys.version_info >= (3, 11): |  | ||||||
|     import tomllib |  | ||||||
| else: |  | ||||||
|     import tomli as tomllib |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def build_variant(): |  | ||||||
|     torch_version = parse(torch.__version__) |  | ||||||
|     cuda_version = parse(torch.version.cuda) |  | ||||||
|     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" |  | ||||||
|     cpu = platform.machine() |  | ||||||
|     os = platform.system().lower() |  | ||||||
|  |  | ||||||
|     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def import_from_path(module_name: str, file_path): |  | ||||||
|     spec = importlib.util.spec_from_file_location(module_name, file_path) |  | ||||||
|     module = importlib.util.module_from_spec(spec) |  | ||||||
|     sys.modules[module_name] = module |  | ||||||
|     spec.loader.exec_module(module) |  | ||||||
|     return module |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def install_kernel(repo_id: str, revision: str): |  | ||||||
|     package_name = get_metadata(repo_id)["torch"]["name"] |  | ||||||
|     repo_path = snapshot_download( |  | ||||||
|         repo_id, allow_patterns=f"build/{build_variant()}/*", revision=revision |  | ||||||
|     ) |  | ||||||
|     return package_name, f"{repo_path}/build/{build_variant()}" |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_metadata(repo_id: str): |  | ||||||
|     with open(hf_hub_download(repo_id, "build.toml"), "rb") as f: |  | ||||||
|         return tomllib.load(f) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_kernel(repo_id: str, revision: str = "main"): |  | ||||||
|     package_name, package_path = install_kernel(repo_id, revision=revision) |  | ||||||
|     return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_kernel(repo_id: str, revision: str = "main"): |  | ||||||
|     filename = hf_hub_download( |  | ||||||
|         repo_id, "build.toml", local_files_only=True, revision=revision |  | ||||||
|     ) |  | ||||||
|     with open(filename, "rb") as f: |  | ||||||
|         metadata = tomllib.load(f) |  | ||||||
|     package_name = metadata["torch"]["name"] |  | ||||||
|     repo_path = os.path.dirname(filename) |  | ||||||
|     package_path = f"{repo_path}/build/{build_variant()}" |  | ||||||
|     return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py") |  | ||||||
							
								
								
									
src/kernels/__init__.py (new file, 23 lines)
							| @ -0,0 +1,23 @@ | |||||||
|  | from kernels.layer import ( | ||||||
|  |     Device, | ||||||
|  |     LayerRepository, | ||||||
|  |     register_kernel_mapping, | ||||||
|  |     use_kernel_forward_from_hub, | ||||||
|  | ) | ||||||
|  | from kernels.utils import ( | ||||||
|  |     get_kernel, | ||||||
|  |     get_locked_kernel, | ||||||
|  |     install_kernel, | ||||||
|  |     load_kernel, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | __all__ = [ | ||||||
|  |     "get_kernel", | ||||||
|  |     "get_locked_kernel", | ||||||
|  |     "load_kernel", | ||||||
|  |     "install_kernel", | ||||||
|  |     "use_kernel_forward_from_hub", | ||||||
|  |     "register_kernel_mapping", | ||||||
|  |     "LayerRepository", | ||||||
|  |     "Device", | ||||||
|  | ] | ||||||
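For orientation, a minimal sketch of how this public API is meant to be used (assuming a CUDA-enabled Torch install and access to the Hub; the `gelu_fast` out-parameter call is inferred from the activation tests further down and is an assumption, not part of this diff):

```python
import torch

from kernels import get_kernel

# Downloads the build variant matching the local Torch/CUDA setup
# and imports it as a regular Python module.
activation = get_kernel("kernels-community/activation")

x = torch.randn(16, 32, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)  # hypothetical out-parameter signature
```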
src/kernels/cli.py (new file, 98 lines)
							| @ -0,0 +1,98 @@ | |||||||
|  | import argparse | ||||||
|  | import dataclasses | ||||||
|  | import json | ||||||
|  | import sys | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from kernels.compat import tomllib | ||||||
|  | from kernels.lockfile import KernelLock, get_kernel_locks | ||||||
|  | from kernels.utils import install_kernel, install_kernel_all_variants | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     parser = argparse.ArgumentParser( | ||||||
|  |         prog="kernel", description="Manage compute kernels" | ||||||
|  |     ) | ||||||
|  |     subparsers = parser.add_subparsers(required=True) | ||||||
|  |  | ||||||
|  |     download_parser = subparsers.add_parser("download", help="Download locked kernels") | ||||||
|  |     download_parser.add_argument( | ||||||
|  |         "project_dir", | ||||||
|  |         type=Path, | ||||||
|  |         help="The project directory", | ||||||
|  |     ) | ||||||
|  |     download_parser.add_argument( | ||||||
|  |         "--all-variants", | ||||||
|  |         action="store_true", | ||||||
|  |         help="Download all build variants of the kernel", | ||||||
|  |     ) | ||||||
|  |     download_parser.set_defaults(func=download_kernels) | ||||||
|  |  | ||||||
|  |     lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") | ||||||
|  |     lock_parser.add_argument( | ||||||
|  |         "project_dir", | ||||||
|  |         type=Path, | ||||||
|  |         help="The project directory", | ||||||
|  |     ) | ||||||
|  |     lock_parser.set_defaults(func=lock_kernels) | ||||||
|  |  | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     args.func(args) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def download_kernels(args): | ||||||
|  |     lock_path = args.project_dir / "kernels.lock" | ||||||
|  |  | ||||||
|  |     if not lock_path.exists(): | ||||||
|  |         print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr) | ||||||
|  |         sys.exit(1) | ||||||
|  |  | ||||||
|  |     with open(args.project_dir / "kernels.lock", "r") as f: | ||||||
|  |         lock_json = json.load(f) | ||||||
|  |  | ||||||
|  |     all_successful = True | ||||||
|  |  | ||||||
|  |     for kernel_lock_json in lock_json: | ||||||
|  |         kernel_lock = KernelLock.from_json(kernel_lock_json) | ||||||
|  |         print( | ||||||
|  |             f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}", | ||||||
|  |             file=sys.stderr, | ||||||
|  |         ) | ||||||
|  |         if args.all_variants: | ||||||
|  |             install_kernel_all_variants( | ||||||
|  |                 kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             try: | ||||||
|  |                 install_kernel( | ||||||
|  |                     kernel_lock.repo_id, | ||||||
|  |                     kernel_lock.sha, | ||||||
|  |                     variant_locks=kernel_lock.variants, | ||||||
|  |                 ) | ||||||
|  |             except FileNotFoundError as e: | ||||||
|  |                 print(e, file=sys.stderr) | ||||||
|  |                 all_successful = False | ||||||
|  |  | ||||||
|  |     if not all_successful: | ||||||
|  |         sys.exit(1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def lock_kernels(args): | ||||||
|  |     with open(args.project_dir / "pyproject.toml", "rb") as f: | ||||||
|  |         data = tomllib.load(f) | ||||||
|  |  | ||||||
|  |     kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", {}) | ||||||
|  |  | ||||||
|  |     all_locks = [] | ||||||
|  |     for kernel, version in kernel_versions.items(): | ||||||
|  |         all_locks.append(get_kernel_locks(kernel, version)) | ||||||
|  |  | ||||||
|  |     with open(args.project_dir / "kernels.lock", "w") as f: | ||||||
|  |         json.dump(all_locks, f, cls=_JSONEncoder, indent=2) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class _JSONEncoder(json.JSONEncoder): | ||||||
|  |     def default(self, o): | ||||||
|  |         if dataclasses.is_dataclass(o): | ||||||
|  |             return dataclasses.asdict(o) | ||||||
|  |         return super().default(o) | ||||||
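The two subcommands map directly onto `lock_kernels` and `download_kernels`, so the same flow can be driven programmatically. A minimal sketch, assuming a project directory containing a `pyproject.toml` with a `[tool.kernels.dependencies]` table (as in the test fixture below); `SimpleNamespace` stands in for the parsed argparse arguments:

```python
from pathlib import Path
from types import SimpleNamespace

from kernels.cli import download_kernels, lock_kernels

project_dir = Path("my-project")  # hypothetical project path

# Resolve the version specifiers in pyproject.toml to pinned commit
# SHAs and variant hashes, writing kernels.lock next to it.
lock_kernels(SimpleNamespace(project_dir=project_dir))

# Download the locked kernels (only the variant for this environment).
download_kernels(SimpleNamespace(project_dir=project_dir, all_variants=False))
```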
src/kernels/compat.py (new file, 8 lines)
							| @ -0,0 +1,8 @@ | |||||||
|  | import sys | ||||||
|  |  | ||||||
|  | if sys.version_info >= (3, 11): | ||||||
|  |     import tomllib | ||||||
|  | else: | ||||||
|  |     import tomli as tomllib | ||||||
|  |  | ||||||
|  | __all__ = ["tomllib"] | ||||||
src/kernels/layer.py (new file, 231 lines)
							| @ -0,0 +1,231 @@ | |||||||
|  | import inspect | ||||||
|  | from contextvars import ContextVar | ||||||
|  | from copy import deepcopy | ||||||
|  | from dataclasses import dataclass, field | ||||||
|  | from typing import TYPE_CHECKING, Callable, Dict, Union | ||||||
|  |  | ||||||
|  | from .utils import get_kernel | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from torch import nn | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass(frozen=True) | ||||||
|  | class Device: | ||||||
|  |     type: str | ||||||
|  |  | ||||||
|  |     # In the future we might add compute capabilities, etc. | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return isinstance(other, Device) and self.type == other.type | ||||||
|  |  | ||||||
|  |     def __hash__(self): | ||||||
|  |         return hash(self.type) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class LayerRepository: | ||||||
|  |     """ | ||||||
|  |     Repository and name of a layer. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     layer_name: str = field( | ||||||
|  |         metadata={"help": "The name of the layer in the kernel repository."} | ||||||
|  |     ) | ||||||
|  |     repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."}) | ||||||
|  |     revision: str = field( | ||||||
|  |         default="main", metadata={"help": "The revision of the layer."} | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return ( | ||||||
|  |             isinstance(other, LayerRepository) | ||||||
|  |             and self.layer_name == other.layer_name | ||||||
|  |             and self.repo_id == other.repo_id | ||||||
|  |             and self.revision == other.revision | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __hash__(self): | ||||||
|  |         return hash((self.layer_name, self.repo_id, self.revision)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | _KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar( | ||||||
|  |     "_KERNEL_MAPPING", default={} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def use_kernel_mapping(mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]): | ||||||
|  |     class ContextManager: | ||||||
|  |         def __enter__(self): | ||||||
|  |             # Mappings always stack on previous mappings. | ||||||
|  |             self.token = _KERNEL_MAPPING.set(deepcopy(_KERNEL_MAPPING.get())) | ||||||
|  |             register_kernel_mapping(mapping) | ||||||
|  |  | ||||||
|  |         def __exit__(self, exc_type, exc_value, traceback): | ||||||
|  |             _KERNEL_MAPPING.reset(self.token) | ||||||
|  |  | ||||||
|  |     return ContextManager() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def register_kernel_mapping( | ||||||
|  |     mapping: Dict[str, Dict[Union[Device, str], LayerRepository]] | ||||||
|  | ): | ||||||
|  |     """ | ||||||
|  |     Allows one to register a mapping between a layer name and the corresponding kernel to use, depending on the device. | ||||||
|  |     This should be used in conjunction with the `use_kernel_forward_from_hub` decorator on the class. | ||||||
|  |     Example usage: | ||||||
|  |  | ||||||
|  |     ```python | ||||||
|  |     from kernels import LayerRepository, register_kernel_mapping | ||||||
|  |  | ||||||
|  |     kernel_layer_mapping = { | ||||||
|  |       "LlamaRMSNorm": { | ||||||
|  |           "cuda": LayerRepository( | ||||||
|  |               repo_id="kernels-community/activation", | ||||||
|  |               layer_name="RmsNorm", | ||||||
|  |               revision="layers", | ||||||
|  |           ), | ||||||
|  |       }, | ||||||
|  |     } | ||||||
|  |     register_kernel_mapping(kernel_layer_mapping) | ||||||
|  |     ``` | ||||||
|  |     """ | ||||||
|  |     # Merge with existing mappings. | ||||||
|  |     for new_kernel, new_device_repos in mapping.items(): | ||||||
|  |         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {}) | ||||||
|  |         for new_device, new_repo in new_device_repos.items(): | ||||||
|  |             if isinstance(new_device, str): | ||||||
|  |                 device_repo[Device(type=new_device)] = new_repo | ||||||
|  |             else: | ||||||
|  |                 device_repo[new_device] = new_repo | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True): | ||||||
|  |     """ | ||||||
|  |     Replace the forward function of a layer using a layer from the kernel hub. | ||||||
|  |     This function monkeypatches a layer, replacing the `forward` method | ||||||
|  |     of the layer with that of a layer from the hub. The replacement is done | ||||||
|  |     when a layer matching `layer_name` and device type is registered through | ||||||
|  |     `register_kernel_mapping`. The device type is inferred from the first | ||||||
|  |     argument to `forward`. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     fallback_forward = cls.forward | ||||||
|  |  | ||||||
|  |     cached_forward: Dict[LayerRepository, Callable] = {} | ||||||
|  |  | ||||||
|  |     def forward(self, x, **args): | ||||||
|  |         kernel = _KERNEL_MAPPING.get().get(layer_name) | ||||||
|  |         if kernel is None: | ||||||
|  |             if not use_fallback: | ||||||
|  |                 raise ValueError(f"No layer mapping for `{layer_name}`") | ||||||
|  |             return fallback_forward(self, x, **args) | ||||||
|  |  | ||||||
|  |         device = getattr(x, "device", None) | ||||||
|  |         if device is None: | ||||||
|  |             return fallback_forward(self, x, **args) | ||||||
|  |  | ||||||
|  |         repo = kernel.get(Device(type=device.type)) | ||||||
|  |         if repo is None: | ||||||
|  |             if not use_fallback: | ||||||
|  |                 raise ValueError( | ||||||
|  |                     f"No layer mapping for `{layer_name}` with device type `{device.type}`" | ||||||
|  |                 ) | ||||||
|  |             return fallback_forward(self, x, **args) | ||||||
|  |  | ||||||
|  |         # Short-circuit if we already loaded the layer. | ||||||
|  |         layer_forward = cached_forward.get(repo, None) | ||||||
|  |         if layer_forward is not None: | ||||||
|  |             return layer_forward(self, x, **args) | ||||||
|  |  | ||||||
|  |         layer = _get_kernel_layer( | ||||||
|  |             repo_id=repo.repo_id, | ||||||
|  |             layer_name=repo.layer_name, | ||||||
|  |             revision=repo.revision, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |         # We have to validate against the original signature. | ||||||
|  |         orig_forward = cls.forward | ||||||
|  |         try: | ||||||
|  |             cls.forward = fallback_forward | ||||||
|  |             _validate_layer(check_cls=cls, cls=layer) | ||||||
|  |         finally: | ||||||
|  |             cls.forward = orig_forward | ||||||
|  |  | ||||||
|  |         layer_forward = layer.forward | ||||||
|  |         cached_forward[repo] = layer_forward | ||||||
|  |  | ||||||
|  |         return layer_forward(self, x, **args) | ||||||
|  |  | ||||||
|  |     cls.forward = forward | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True): | ||||||
|  |     """ | ||||||
|  |     Replace the forward function of a layer using a layer from the kernel hub. | ||||||
|  |     This decorator can be applied to a layer and replaces the forward method | ||||||
|  |     of the layer with that of a layer from the hub. The replacement is done | ||||||
|  |     when a layer matching `layer_name` and device type is registered through | ||||||
|  |     `register_kernel_mapping`. The device type is inferred from the first | ||||||
|  |     argument to `forward`. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def decorator(cls): | ||||||
|  |         replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback) | ||||||
|  |         return cls | ||||||
|  |  | ||||||
|  |     return decorator | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module": | ||||||
|  |     """Get a layer from a kernel.""" | ||||||
|  |  | ||||||
|  |     kernel = get_kernel(repo_id, revision=revision) | ||||||
|  |  | ||||||
|  |     if getattr(kernel, "layers", None) is None: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Kernel `{repo_id}` at revision `{revision}` does not define any layers." | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     layer = getattr(kernel.layers, layer_name, None) | ||||||
|  |     if layer is None: | ||||||
|  |         raise ValueError(f"Layer `{layer_name}` not found in kernel `{repo_id}`.") | ||||||
|  |     return layer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _validate_layer(*, check_cls, cls): | ||||||
|  |     # The layer must have at least have the following properties: (1) it | ||||||
|  |     # must be stateless; (2) the forward signature should correspond to | ||||||
|  |     # the signature it is replacing; (3) forward should not call other | ||||||
|  |     # methods. | ||||||
|  |  | ||||||
|  |     from torch import nn | ||||||
|  |  | ||||||
|  |     if not issubclass(cls, nn.Module): | ||||||
|  |         raise TypeError(f"Layer `{cls}` is not a Torch layer.") | ||||||
|  |  | ||||||
|  |     # We verify statelessness by checking that the layer does not have its | ||||||
|  |     # own constructor (since the constructor could add member variables)... | ||||||
|  |     if cls.__init__ is not nn.Module.__init__: | ||||||
|  |         raise TypeError("Layer must not override nn.Module constructor.") | ||||||
|  |  | ||||||
|  |     # ... or predefined member variables. | ||||||
|  |     torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)} | ||||||
|  |     cls_members = {name for name, _ in inspect.getmembers(cls)} | ||||||
|  |     if cls_members - torch_module_members != set(): | ||||||
|  |         raise TypeError("Layer must not contain additional members.") | ||||||
|  |  | ||||||
|  |     # Check whether the forward signatures are similar. | ||||||
|  |     params = inspect.signature(cls.forward).parameters | ||||||
|  |     ref_params = inspect.signature(check_cls.forward).parameters | ||||||
|  |  | ||||||
|  |     if len(params) != len(ref_params): | ||||||
|  |         raise TypeError( | ||||||
|  |             "Forward signature does not match: different number of arguments." | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     for param, ref_param in zip(params.values(), ref_params.values()): | ||||||
|  |         if param.kind != ref_param.kind: | ||||||
|  |             raise TypeError( | ||||||
|  |                 f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})" | ||||||
|  |             ) | ||||||
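Taken together: a layer opts in through the decorator, a mapping activates a hub kernel per device type, and the original `forward` serves as the fallback. A sketch under those rules (repository and layer name reused from the `register_kernel_mapping` docstring above; the RMS-norm fallback body is illustrative):

```python
import torch
import torch.nn as nn

from kernels import (
    LayerRepository,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("LlamaRMSNorm")
class LlamaRMSNorm(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fallback used when no kernel is mapped for this layer
        # name and device type.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)


register_kernel_mapping(
    {
        "LlamaRMSNorm": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="RmsNorm",
                revision="layers",
            )
        }
    }
)

norm = LlamaRMSNorm()
y = norm(torch.randn(2, 64))  # CPU tensor: falls back to the forward above
```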
src/kernels/lockfile.py (new file, 135 lines)
							| @ -0,0 +1,135 @@ | |||||||
|  | import hashlib | ||||||
|  | from dataclasses import dataclass | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import Dict, List, Tuple | ||||||
|  |  | ||||||
|  | from huggingface_hub import HfApi | ||||||
|  | from huggingface_hub.hf_api import GitRefInfo | ||||||
|  | from packaging.specifiers import SpecifierSet | ||||||
|  | from packaging.version import InvalidVersion, Version | ||||||
|  |  | ||||||
|  | from kernels.compat import tomllib | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class VariantLock: | ||||||
|  |     hash: str | ||||||
|  |     hash_type: str = "git_lfs_concat" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class KernelLock: | ||||||
|  |     repo_id: str | ||||||
|  |     sha: str | ||||||
|  |     variants: Dict[str, VariantLock] | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_json(cls, o: Dict): | ||||||
|  |         variants = { | ||||||
|  |             variant: VariantLock(**lock) for variant, lock in o["variants"].items() | ||||||
|  |         } | ||||||
|  |         return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: | ||||||
|  |     """Get kernel versions that are available in the repository.""" | ||||||
|  |     versions = {} | ||||||
|  |     for tag in HfApi().list_repo_refs(repo_id).tags: | ||||||
|  |         if not tag.name.startswith("v"): | ||||||
|  |             continue | ||||||
|  |         try: | ||||||
|  |             versions[Version(tag.name[1:])] = tag | ||||||
|  |         except InvalidVersion: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|  |     return versions | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: | ||||||
|  |     """ | ||||||
|  |     Get the locks for a kernel with the given version spec. | ||||||
|  |  | ||||||
|  |     The version specifier can be any valid Python version specifier: | ||||||
|  |     https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers | ||||||
|  |     """ | ||||||
|  |     versions = _get_available_versions(repo_id) | ||||||
|  |     requirement = SpecifierSet(version_spec) | ||||||
|  |     accepted_versions = sorted(requirement.filter(versions.keys())) | ||||||
|  |  | ||||||
|  |     if len(accepted_versions) == 0: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"No version of `{repo_id}` satisfies requirement: {version_spec}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     tag_for_newest = versions[accepted_versions[-1]] | ||||||
|  |  | ||||||
|  |     r = HfApi().repo_info( | ||||||
|  |         repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True | ||||||
|  |     ) | ||||||
|  |     if r.sha is None: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     if r.siblings is None: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     variant_files: Dict[str, List[Tuple[bytes, str]]] = {} | ||||||
|  |     for sibling in r.siblings: | ||||||
|  |         if sibling.rfilename.startswith("build/torch"): | ||||||
|  |             if sibling.blob_id is None: | ||||||
|  |                 raise ValueError(f"Cannot get blob ID for {sibling.rfilename}") | ||||||
|  |  | ||||||
|  |             path = Path(sibling.rfilename) | ||||||
|  |             variant = path.parts[1] | ||||||
|  |             filename = Path(*path.parts[2:]) | ||||||
|  |  | ||||||
|  |             hash = sibling.lfs.sha256 if sibling.lfs is not None else sibling.blob_id | ||||||
|  |  | ||||||
|  |             files = variant_files.setdefault(variant, []) | ||||||
|  |  | ||||||
|  |             # Encode as posix for consistent slash handling, then encode | ||||||
|  |             # as utf-8 for byte-wise sorting later. | ||||||
|  |             files.append((filename.as_posix().encode("utf-8"), hash)) | ||||||
|  |  | ||||||
|  |     variant_locks = {} | ||||||
|  |     for variant, files in variant_files.items(): | ||||||
|  |         m = hashlib.sha256() | ||||||
|  |         for filename_bytes, hash in sorted(files): | ||||||
|  |             # Filename as bytes. | ||||||
|  |             m.update(filename_bytes) | ||||||
|  |             # Git blob or LFS file hash as bytes. | ||||||
|  |             m.update(bytes.fromhex(hash)) | ||||||
|  |  | ||||||
|  |         variant_locks[variant] = VariantLock(hash=f"sha256-{m.hexdigest()}") | ||||||
|  |  | ||||||
|  |     return KernelLock(repo_id=repo_id, sha=r.sha, variants=variant_locks) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def write_egg_lockfile(cmd, basename, filename): | ||||||
|  |     import logging | ||||||
|  |  | ||||||
|  |     cwd = Path.cwd() | ||||||
|  |     pyproject_path = cwd / "pyproject.toml" | ||||||
|  |     if not pyproject_path.exists(): | ||||||
|  |         # Nothing to do if the project doesn't have pyproject.toml. | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     with open(pyproject_path, "rb") as f: | ||||||
|  |         data = tomllib.load(f) | ||||||
|  |  | ||||||
|  |     kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None) | ||||||
|  |     if kernel_versions is None: | ||||||
|  |         return | ||||||
|  |  | ||||||
|  |     lock_path = cwd / "kernels.lock" | ||||||
|  |     if not lock_path.exists(): | ||||||
|  |         logging.warning(f"Lock file {lock_path} does not exist") | ||||||
|  |         # Ensure that the file gets deleted in editable installs. | ||||||
|  |         data = None | ||||||
|  |     else: | ||||||
|  |         data = open(lock_path, "r").read() | ||||||
|  |  | ||||||
|  |     cmd.write_or_delete_file(basename, filename, data) | ||||||
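The variant hash recorded in the lock file is a SHA-256 over the byte-sorted (relative filename, per-file Git/LFS digest) pairs collected above, which is what lets `validate_kernel` in `utils.py` recompute it from a local snapshot. A small sketch of the scheme on placeholder data (the paths and hex digests are made up):

```python
import hashlib

# (relative path as UTF-8 bytes, hex digest of the blob) pairs, as
# collected in get_kernel_locks. 40 hex chars = Git blob (SHA-1),
# 64 hex chars = LFS object (SHA-256).
files = [
    (b"activation/_activation.so", "bb" * 32),
    (b"activation/__init__.py", "aa" * 20),
]

m = hashlib.sha256()
for filename_bytes, hex_digest in sorted(files):
    m.update(filename_bytes)
    m.update(bytes.fromhex(hex_digest))

print(f"sha256-{m.hexdigest()}")  # the VariantLock.hash format
```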
src/kernels/utils.py (new file, 308 lines)
							| @ -0,0 +1,308 @@ | |||||||
|  | import ctypes | ||||||
|  | import hashlib | ||||||
|  | import importlib | ||||||
|  | import importlib.metadata | ||||||
|  | import inspect | ||||||
|  | import json | ||||||
|  | import os | ||||||
|  | import platform | ||||||
|  | import sys | ||||||
|  | from importlib.metadata import Distribution | ||||||
|  | from pathlib import Path | ||||||
|  | from types import ModuleType | ||||||
|  | from typing import Dict, List, Optional, Tuple | ||||||
|  |  | ||||||
|  | from huggingface_hub import snapshot_download | ||||||
|  | from packaging.version import parse | ||||||
|  |  | ||||||
|  | from kernels.lockfile import KernelLock, VariantLock | ||||||
|  |  | ||||||
|  | CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def build_variant() -> str: | ||||||
|  |     import torch | ||||||
|  |  | ||||||
|  |     if torch.version.cuda is None: | ||||||
|  |         raise AssertionError( | ||||||
|  |             "This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled." | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     torch_version = parse(torch.__version__) | ||||||
|  |     cuda_version = parse(torch.version.cuda) | ||||||
|  |     cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" | ||||||
|  |     cpu = platform.machine() | ||||||
|  |     os = platform.system().lower() | ||||||
|  |  | ||||||
|  |     return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def universal_build_variant() -> str: | ||||||
|  |     # Once we support other frameworks, detection goes here. | ||||||
|  |     return "torch-universal" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def import_from_path(module_name: str, file_path: Path) -> ModuleType: | ||||||
|  |     # We cannot use the module name as-is, after adding it to `sys.modules`, | ||||||
|  |     # it would also be used for other imports. So, we make a module name that | ||||||
|  |     # depends on the path for it to be unique using the hex-encoded hash of | ||||||
|  |     # the path. | ||||||
|  |     path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value) | ||||||
|  |     module_name = f"{module_name}_{path_hash}" | ||||||
|  |     spec = importlib.util.spec_from_file_location(module_name, file_path) | ||||||
|  |     if spec is None: | ||||||
|  |         raise ImportError(f"Cannot load spec for {module_name} from {file_path}") | ||||||
|  |     module = importlib.util.module_from_spec(spec) | ||||||
|  |     if module is None: | ||||||
|  |         raise ImportError(f"Cannot load module {module_name} from spec") | ||||||
|  |     sys.modules[module_name] = module | ||||||
|  |     spec.loader.exec_module(module)  # type: ignore | ||||||
|  |     return module | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def install_kernel( | ||||||
|  |     repo_id: str, | ||||||
|  |     revision: str, | ||||||
|  |     local_files_only: bool = False, | ||||||
|  |     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||||
|  | ) -> Tuple[str, Path]: | ||||||
|  |     """ | ||||||
|  |     Download a kernel for the current environment to the cache. | ||||||
|  |  | ||||||
|  |     The output path is validated against the hashes in `variant_locks` when set. | ||||||
|  |     """ | ||||||
|  |     package_name = package_name_from_repo_id(repo_id) | ||||||
|  |     variant = build_variant() | ||||||
|  |     universal_variant = universal_build_variant() | ||||||
|  |     repo_path = Path( | ||||||
|  |         snapshot_download( | ||||||
|  |             repo_id, | ||||||
|  |             allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"], | ||||||
|  |             cache_dir=CACHE_DIR, | ||||||
|  |             revision=revision, | ||||||
|  |             local_files_only=local_files_only, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     variant_path = repo_path / "build" / variant | ||||||
|  |     universal_variant_path = repo_path / "build" / universal_variant | ||||||
|  |  | ||||||
|  |     if not variant_path.exists() and universal_variant_path.exists(): | ||||||
|  |         # Fall back to universal variant. | ||||||
|  |         variant = universal_variant | ||||||
|  |         variant_path = universal_variant_path | ||||||
|  |  | ||||||
|  |     if variant_locks is not None: | ||||||
|  |         variant_lock = variant_locks.get(variant) | ||||||
|  |         if variant_lock is None: | ||||||
|  |             raise ValueError(f"No lock found for build variant: {variant}") | ||||||
|  |         validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash) | ||||||
|  |  | ||||||
|  |     module_init_path = variant_path / package_name / "__init__.py" | ||||||
|  |  | ||||||
|  |     if not os.path.exists(module_init_path): | ||||||
|  |         raise FileNotFoundError( | ||||||
|  |             f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     return package_name, variant_path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def install_kernel_all_variants( | ||||||
|  |     repo_id: str, | ||||||
|  |     revision: str, | ||||||
|  |     local_files_only: bool = False, | ||||||
|  |     variant_locks: Optional[Dict[str, VariantLock]] = None, | ||||||
|  | ) -> Path: | ||||||
|  |     repo_path = Path( | ||||||
|  |         snapshot_download( | ||||||
|  |             repo_id, | ||||||
|  |             allow_patterns="build/*", | ||||||
|  |             cache_dir=CACHE_DIR, | ||||||
|  |             revision=revision, | ||||||
|  |             local_files_only=local_files_only, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     if variant_locks is not None: | ||||||
|  |         for entry in (repo_path / "build").iterdir(): | ||||||
|  |             variant = entry.parts[-1] | ||||||
|  |  | ||||||
|  |             variant_lock = variant_locks.get(variant) | ||||||
|  |             if variant_lock is None: | ||||||
|  |                 raise ValueError(f"No lock found for build variant: {variant}") | ||||||
|  |  | ||||||
|  |             validate_kernel( | ||||||
|  |                 repo_path=repo_path, variant=variant, hash=variant_lock.hash | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     return repo_path / "build" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_kernel(repo_id: str, revision: str = "main") -> ModuleType: | ||||||
|  |     package_name, package_path = install_kernel(repo_id, revision=revision) | ||||||
|  |     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: | ||||||
|  |     """ | ||||||
|  |     Get a pre-downloaded, locked kernel. | ||||||
|  |  | ||||||
|  |     If `lockfile` is not specified, the lockfile will be loaded from the | ||||||
|  |     caller's package metadata. | ||||||
|  |     """ | ||||||
|  |     if lockfile is None: | ||||||
|  |         locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
|  |     else: | ||||||
|  |         with open(lockfile, "r") as f: | ||||||
|  |             locked_sha = _get_locked_kernel(repo_id, f.read()) | ||||||
|  |  | ||||||
|  |     if locked_sha is None: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project." | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     package_name = package_name_from_repo_id(repo_id) | ||||||
|  |  | ||||||
|  |     variant = build_variant() | ||||||
|  |     universal_variant = universal_build_variant() | ||||||
|  |  | ||||||
|  |     repo_path = Path( | ||||||
|  |         snapshot_download( | ||||||
|  |             repo_id, | ||||||
|  |             allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"], | ||||||
|  |             cache_dir=CACHE_DIR, | ||||||
|  |             revision=locked_sha, | ||||||
|  |             local_files_only=True, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     variant_path = repo_path / "build" / variant | ||||||
|  |     universal_variant_path = repo_path / "build" / universal_variant | ||||||
|  |     if not variant_path.exists() and universal_variant_path.exists(): | ||||||
|  |         # Fall back to universal variant. | ||||||
|  |         variant = universal_variant | ||||||
|  |         variant_path = universal_variant_path | ||||||
|  |  | ||||||
|  |     module_init_path = variant_path / package_name / "__init__.py" | ||||||
|  |     if not os.path.exists(module_init_path): | ||||||
|  |         raise FileNotFoundError( | ||||||
|  |             f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     return import_from_path(package_name, variant_path / package_name / "__init__.py") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: | ||||||
|  |     """Get a kernel using a lock file.""" | ||||||
|  |     locked_sha = _get_caller_locked_kernel(repo_id) | ||||||
|  |  | ||||||
|  |     if locked_sha is None: | ||||||
|  |         raise ValueError(f"Kernel `{repo_id}` is not locked") | ||||||
|  |  | ||||||
|  |     package_name, package_path = install_kernel( | ||||||
|  |         repo_id, locked_sha, local_files_only=local_files_only | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     return import_from_path(package_name, package_path / package_name / "__init__.py") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_caller_locked_kernel(repo_id: str) -> Optional[str]: | ||||||
|  |     for dist in _get_caller_distributions(): | ||||||
|  |         lock_json = dist.read_text("kernels.lock") | ||||||
|  |         if lock_json is None: | ||||||
|  |             continue | ||||||
|  |         locked_sha = _get_locked_kernel(repo_id, lock_json) | ||||||
|  |         if locked_sha is not None: | ||||||
|  |             return locked_sha | ||||||
|  |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_locked_kernel(repo_id: str, lock_json: str) -> Optional[str]: | ||||||
|  |     for kernel_lock_json in json.loads(lock_json): | ||||||
|  |         kernel_lock = KernelLock.from_json(kernel_lock_json) | ||||||
|  |         if kernel_lock.repo_id == repo_id: | ||||||
|  |             return kernel_lock.sha | ||||||
|  |     return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_caller_distributions() -> List[Distribution]: | ||||||
|  |     module = _get_caller_module() | ||||||
|  |     if module is None: | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|  |     # Look up all possible distributions that this module could be from. | ||||||
|  |     package = module.__name__.split(".")[0] | ||||||
|  |     dist_names = importlib.metadata.packages_distributions().get(package) | ||||||
|  |     if dist_names is None: | ||||||
|  |         return [] | ||||||
|  |  | ||||||
|  |     return [importlib.metadata.distribution(dist_name) for dist_name in dist_names] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_caller_module() -> Optional[ModuleType]: | ||||||
|  |     stack = inspect.stack() | ||||||
|  |     # Get first module in the stack that is not the current module. | ||||||
|  |     first_module = inspect.getmodule(stack[0][0]) | ||||||
|  |     for frame in stack[1:]: | ||||||
|  |         module = inspect.getmodule(frame[0]) | ||||||
|  |         if module is not None and module != first_module: | ||||||
|  |             return module | ||||||
|  |     return first_module | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def validate_kernel(*, repo_path: Path, variant: str, hash: str): | ||||||
|  |     """Validate the given build variant of a kernel against a hasht.""" | ||||||
|  |     variant_path = repo_path / "build" / variant | ||||||
|  |  | ||||||
|  |     # Get the file paths. The first element is a byte-encoded relative path | ||||||
|  |     # used for sorting. The second element is the absolute path. | ||||||
|  |     files: List[Tuple[bytes, Path]] = [] | ||||||
|  |     # Ideally we'd use Path.walk, but it's only available in Python 3.12. | ||||||
|  |     for dirpath, _, filenames in os.walk(variant_path): | ||||||
|  |         for filename in filenames: | ||||||
|  |             file_abs = Path(dirpath) / filename | ||||||
|  |  | ||||||
|  |             # Python creates extra files (e.g. __pycache__) when importing | ||||||
|  |             # modules from the cache, so only hash files that are symlinked blobs. | ||||||
|  |             if file_abs.is_symlink(): | ||||||
|  |                 files.append( | ||||||
|  |                     ( | ||||||
|  |                         file_abs.relative_to(variant_path).as_posix().encode("utf-8"), | ||||||
|  |                         file_abs, | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|  |     m = hashlib.sha256() | ||||||
|  |  | ||||||
|  |     for filename_bytes, full_path in sorted(files): | ||||||
|  |         m.update(filename_bytes) | ||||||
|  |  | ||||||
|  |         blob_filename = full_path.resolve().name | ||||||
|  |         if len(blob_filename) == 40: | ||||||
|  |             # SHA-1 hashed, so a Git blob. | ||||||
|  |             m.update(git_hash_object(full_path.read_bytes())) | ||||||
|  |         elif len(blob_filename) == 64: | ||||||
|  |             # SHA-256 hashed, so a Git LFS blob. | ||||||
|  |             m.update(hashlib.sha256(full_path.read_bytes()).digest()) | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Unexpected blob filename length: {len(blob_filename)}") | ||||||
|  |  | ||||||
|  |     computedHash = f"sha256-{m.hexdigest()}" | ||||||
|  |     if computedHash != hash: | ||||||
|  |         raise ValueError( | ||||||
|  |             f"Lock file specifies kernel with hash {hash}, but downloaded kernel has hash: {computedHash}" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def git_hash_object(data: bytes, object_type: str = "blob") -> bytes: | ||||||
|  |     """Calculate git SHA1 of data.""" | ||||||
|  |     header = f"{object_type} {len(data)}\0".encode() | ||||||
|  |     m = hashlib.sha1() | ||||||
|  |     m.update(header) | ||||||
|  |     m.update(data) | ||||||
|  |     return m.digest() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def package_name_from_repo_id(repo_id: str) -> str: | ||||||
|  |     return repo_id.split("/")[-1].replace("-", "_") | ||||||
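The variant string produced by `build_variant` pins every ABI dimension a compiled kernel depends on: Torch major/minor, C++ ABI, CUDA major/minor, CPU architecture, and OS. For instance, Torch 2.5 built with the C++11 ABI against CUDA 12.4 on x86-64 Linux yields:

```python
>>> build_variant()  # illustrative output for that environment
'torch25-cxx11-cu124-x86_64-linux'
```

which matches one of the variant keys in the lock file fixture below.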
tests/kernel_locking/kernels.lock (new file, 66 lines)
							| @ -0,0 +1,66 @@ | |||||||
|  | [ | ||||||
|  |   { | ||||||
|  |     "repo_id": "kernels-community/activation", | ||||||
|  |     "sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0", | ||||||
|  |     "variants": { | ||||||
|  |       "torch25-cxx11-cu118-x86_64-linux": { | ||||||
|  |         "hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch25-cxx11-cu121-x86_64-linux": { | ||||||
|  |         "hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch25-cxx11-cu124-x86_64-linux": { | ||||||
|  |         "hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch25-cxx98-cu118-x86_64-linux": { | ||||||
|  |         "hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch25-cxx98-cu121-x86_64-linux": { | ||||||
|  |         "hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch25-cxx98-cu124-x86_64-linux": { | ||||||
|  |         "hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx11-cu118-x86_64-linux": { | ||||||
|  |         "hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx11-cu124-x86_64-linux": { | ||||||
|  |         "hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx11-cu126-x86_64-linux": { | ||||||
|  |         "hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx98-cu118-x86_64-linux": { | ||||||
|  |         "hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx98-cu124-x86_64-linux": { | ||||||
|  |         "hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       }, | ||||||
|  |       "torch26-cxx98-cu126-x86_64-linux": { | ||||||
|  |         "hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |     "repo_id": "kernels-community/triton-scaled-mm", | ||||||
|  |     "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f", | ||||||
|  |     "variants": { | ||||||
|  |       "torch-universal": { | ||||||
|  |         "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52", | ||||||
|  |         "hash_type": "git_lfs_concat" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | ] | ||||||
tests/kernel_locking/pyproject.toml (new file, 3 lines)
							| @ -0,0 +1,3 @@ | |||||||
|  | [tool.kernels.dependencies] | ||||||
|  | "kernels-community/activation" = ">=0.0.2" | ||||||
|  | "kernels-community/triton-scaled-mm" = ">=0.0.2" | ||||||
| @ -1,6 +1,7 @@ | |||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
| from hf_kernels import get_kernel |  | ||||||
|  | from kernels import get_kernel | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @ -8,6 +9,11 @@ def kernel(): | |||||||
|     return get_kernel("kernels-community/activation") |     return get_kernel("kernels-community/activation") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture | ||||||
|  | def universal_kernel(): | ||||||
|  |     return get_kernel("kernels-community/triton-scaled-mm") | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def device(): | def device(): | ||||||
|     if not torch.cuda.is_available(): |     if not torch.cuda.is_available(): | ||||||
| @ -28,3 +34,17 @@ def test_gelu_fast(kernel, device): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     assert torch.allclose(y, expected) |     assert torch.allclose(y, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_universal_kernel(universal_kernel): | ||||||
|  |     torch.manual_seed(0) | ||||||
|  |     A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda") | ||||||
|  |     B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda") | ||||||
|  |     scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda") | ||||||
|  |     scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda") | ||||||
|  |  | ||||||
|  |     out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16) | ||||||
|  |     out_check = (A * scale_a) @ (B * scale_b) | ||||||
|  |     out_check = out_check.to(torch.float16) | ||||||
|  |  | ||||||
|  |     torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1) | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
| from hf_kernels import get_kernel |  | ||||||
|  | from kernels import get_kernel | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
|  | |||||||
tests/test_kernel_locking.py (new file, 24 lines)
							| @ -0,0 +1,24 @@ | |||||||
|  | from dataclasses import dataclass | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | from kernels import load_kernel | ||||||
|  | from kernels.cli import download_kernels | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Mock download arguments class. | ||||||
|  | @dataclass | ||||||
|  | class DownloadArgs: | ||||||
|  |     all_variants: bool | ||||||
|  |     project_dir: Path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_download_all_hash_validation(): | ||||||
|  |     project_dir = Path(__file__).parent / "kernel_locking" | ||||||
|  |     download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_load_locked(): | ||||||
|  |     project_dir = Path(__file__).parent / "kernel_locking" | ||||||
|  |     # Also validates that hashing works correctly. | ||||||
|  |     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) | ||||||
|  |     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") | ||||||
tests/test_layer.py (new file, 168 lines)
							| @ -0,0 +1,168 @@ | |||||||
|  | import pytest | ||||||
|  | import torch | ||||||
|  | import torch.nn as nn | ||||||
|  | from torch.nn import functional as F | ||||||
|  |  | ||||||
|  | from kernels import ( | ||||||
|  |     Device, | ||||||
|  |     LayerRepository, | ||||||
|  |     register_kernel_mapping, | ||||||
|  |     use_kernel_forward_from_hub, | ||||||
|  | ) | ||||||
|  | from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping | ||||||
|  |  | ||||||
|  | kernel_layer_mapping = { | ||||||
|  |     "SiluAndMul": { | ||||||
|  |         Device(type="cuda"): LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             revision="layers", | ||||||
|  |         ) | ||||||
|  |     }, | ||||||
|  |     "SiluAndMulStringDevice": { | ||||||
|  |         "cuda": LayerRepository( | ||||||
|  |             repo_id="kernels-community/activation", | ||||||
|  |             layer_name="SiluAndMul", | ||||||
|  |             revision="layers", | ||||||
|  |         ) | ||||||
|  |     }, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | register_kernel_mapping(kernel_layer_mapping) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SiluAndMul(nn.Module): | ||||||
|  |     def __init__(self): | ||||||
|  |         super().__init__() | ||||||
|  |         # Used to check that we called hub kernel. | ||||||
|  |         self.n_calls = 0 | ||||||
|  |  | ||||||
|  |     def forward(self, input: torch.Tensor) -> torch.Tensor: | ||||||
|  |         self.n_calls += 1 | ||||||
|  |         d = input.shape[-1] // 2 | ||||||
|  |         return F.silu(input[..., :d]) * input[..., d:] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @use_kernel_forward_from_hub("SiluAndMul") | ||||||
|  | class SiluAndMulWithKernel(SiluAndMul): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @use_kernel_forward_from_hub("SiluAndMulStringDevice") | ||||||
|  | class SiluAndMulStringDevice(SiluAndMul): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) | ||||||
|  | @pytest.mark.parametrize("device", ["cuda", "cpu"]) | ||||||
|  | def test_hub_forward(cls, device): | ||||||
|  |     torch.random.manual_seed(0) | ||||||
|  |  | ||||||
|  |     silu_and_mul = SiluAndMul() | ||||||
|  |     X = torch.randn((32, 64), device=device) | ||||||
|  |     Y = silu_and_mul(X) | ||||||
|  |  | ||||||
|  |     silu_and_mul_with_kernel = cls() | ||||||
|  |     Y_kernel = silu_and_mul_with_kernel(X) | ||||||
|  |  | ||||||
|  |     torch.testing.assert_close(Y_kernel, Y) | ||||||
|  |  | ||||||
|  |     assert silu_and_mul.n_calls == 1 | ||||||
|  |     if device == "cuda": | ||||||
|  |         assert silu_and_mul_with_kernel.n_calls == 0 | ||||||
|  |     else: | ||||||
|  |         assert silu_and_mul_with_kernel.n_calls == 1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_layer_fallback_works(): | ||||||
|  |     @use_kernel_forward_from_hub("SiluAndMulNonExisting") | ||||||
|  |     class SiluAndMulWithKernelFallback(SiluAndMul): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     # Check that we don't raise an exception for a non-existing kernel. | ||||||
|  |     SiluAndMulWithKernelFallback() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_mapping_contexts(): | ||||||
|  |     assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"} | ||||||
|  |  | ||||||
|  |     extra_mapping1 = { | ||||||
|  |         "TestKernel": { | ||||||
|  |             Device(type="cuda"): LayerRepository( | ||||||
|  |                 repo_id="kernels-community/activation", | ||||||
|  |                 layer_name="SiluAndMul", | ||||||
|  |                 revision="layers", | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     with use_kernel_mapping(extra_mapping1): | ||||||
|  |         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|  |             "SiluAndMul", | ||||||
|  |             "SiluAndMulStringDevice", | ||||||
|  |             "TestKernel", | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         extra_mapping2 = { | ||||||
|  |             "SiluAndMul": { | ||||||
|  |                 Device(type="cuda"): LayerRepository( | ||||||
|  |                     repo_id="kernels-community/non-existing", | ||||||
|  |                     layer_name="SiluAndMul", | ||||||
|  |                     revision="layers", | ||||||
|  |                 ) | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         with use_kernel_mapping(extra_mapping2): | ||||||
|  |             assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|  |                 "SiluAndMul", | ||||||
|  |                 "SiluAndMulStringDevice", | ||||||
|  |                 "TestKernel", | ||||||
|  |             } | ||||||
|  |             assert ( | ||||||
|  |                 _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|  |                 == "kernels-community/non-existing" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|  |             "SiluAndMul", | ||||||
|  |             "SiluAndMulStringDevice", | ||||||
|  |             "TestKernel", | ||||||
|  |         } | ||||||
|  |         assert ( | ||||||
|  |             _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id | ||||||
|  |             == "kernels-community/activation" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     assert set(_KERNEL_MAPPING.get().keys()) == { | ||||||
|  |         "SiluAndMul", | ||||||
|  |         "SiluAndMulStringDevice", | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_validate_kernel_layer(): | ||||||
|  |     class BadLayer(nn.Module): | ||||||
|  |         def __init__(self, *args, **kwargs): | ||||||
|  |             super().__init__(*args, **kwargs) | ||||||
|  |             self.foo = 42 | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError, match="not override"): | ||||||
|  |         _validate_layer(cls=BadLayer, check_cls=SiluAndMul) | ||||||
|  |  | ||||||
|  |     class BadLayer2(nn.Module): | ||||||
|  |         foo: int = 42 | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError, match="not contain additional members"): | ||||||
|  |         _validate_layer(cls=BadLayer2, check_cls=SiluAndMul) | ||||||
|  |  | ||||||
|  |     class BadLayer3(nn.Module): | ||||||
|  |         def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ... | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError, match="different number of arguments"): | ||||||
|  |         _validate_layer(cls=BadLayer3, check_cls=SiluAndMul) | ||||||
|  |  | ||||||
|  |     class BadLayer4(nn.Module): | ||||||
|  |         def forward(self, *, x: torch.Tensor) -> torch.Tensor: ... | ||||||
|  |  | ||||||
|  |     with pytest.raises(TypeError, match="different kind of arguments"): | ||||||
|  |         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul) | ||||||