mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-24 15:44:32 +08:00
Compare commits
26 Commits
v0.8.1
...
stateful-l
| Author | SHA1 | Date | |
|---|---|---|---|
| a988871e9e | |||
| 055a953552 | |||
| 692d5ad458 | |||
| 2139df57f4 | |||
| 8f9a77bb6a | |||
| 6c00194680 | |||
| d6b51eefb7 | |||
| d383fdd4b4 | |||
| 07e5e8481a | |||
| 88f55d4728 | |||
| e801ebf332 | |||
| 0ae07f05fc | |||
| 7611021100 | |||
| 767e7ccf13 | |||
| 1caa4c1393 | |||
| da701bf58a | |||
| 703664ed31 | |||
| a8a6564fa7 | |||
| c89e0fa9b9 | |||
| 176a601178 | |||
| cfa0c76ddc | |||
| bcc29915f9 | |||
| 6fbff7a9cb | |||
| f7490bd0a9 | |||
| 8069e3bf0c | |||
| c540d1e1d6 |
17
.github/workflows/build_documentation.yaml
vendored
Normal file
17
.github/workflows/build_documentation.yaml
vendored
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
name: Build documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- doc-builder*
|
||||||
|
- v*-release
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
|
||||||
|
with:
|
||||||
|
commit_sha: ${{ github.sha }}
|
||||||
|
package: kernels
|
||||||
|
secrets:
|
||||||
|
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||||
15
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
15
.github/workflows/build_pr_documentation.yaml
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
name: Build PR Documentation
|
||||||
|
|
||||||
|
on: pull_request
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||||
|
with:
|
||||||
|
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||||
|
pr_number: ${{ github.event.number }}
|
||||||
|
package: kernels
|
||||||
21
.github/workflows/lint.yml
vendored
21
.github/workflows/lint.yml
vendored
@ -8,3 +8,24 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- name: Run ruff
|
- name: Run ruff
|
||||||
uses: astral-sh/ruff-action@v3
|
uses: astral-sh/ruff-action@v3
|
||||||
|
|
||||||
|
black:
|
||||||
|
name: Run black check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
UV_PYTHON_PREFERENCE: only-managed
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install uv and set the python version
|
||||||
|
uses: astral-sh/setup-uv@v5
|
||||||
|
with:
|
||||||
|
python-version: 3.12
|
||||||
|
|
||||||
|
- name: Install black
|
||||||
|
run: uv pip install black
|
||||||
|
|
||||||
|
- name: Check formatting
|
||||||
|
run: |
|
||||||
|
uv run black --check src
|
||||||
|
uv run black --check tests
|
||||||
|
|||||||
5
.github/workflows/test.yml
vendored
5
.github/workflows/test.yml
vendored
@ -51,7 +51,10 @@ jobs:
|
|||||||
run: uv run mypy src/kernels
|
run: uv run mypy src/kernels
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: uv run pytest tests
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
run: |
|
||||||
|
uv run pytest tests
|
||||||
|
|
||||||
- name: Check kernel conversion
|
- name: Check kernel conversion
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
16
.github/workflows/upload_pr_documentation.yaml
vendored
Normal file
16
.github/workflows/upload_pr_documentation.yaml
vendored
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
name: Upload PR Documentation
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_run:
|
||||||
|
workflows: ["Build PR Documentation"]
|
||||||
|
types:
|
||||||
|
- completed
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
|
||||||
|
with:
|
||||||
|
package_name: kernels
|
||||||
|
secrets:
|
||||||
|
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||||
|
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
||||||
14
README.md
14
README.md
@ -56,10 +56,12 @@ the Hub.
|
|||||||
|
|
||||||
## 📚 Documentation
|
## 📚 Documentation
|
||||||
|
|
||||||
- [Using layers](docs/layers.md)
|
- [Introduction](docs/source/index.md)
|
||||||
- [Locking kernel/layer versions](docs/locking.md)
|
- [Installation](docs/source/installation.md)
|
||||||
- [Environment variables](docs/env.md)
|
- [Basic usage](docs/source/basic-usage.md)
|
||||||
- [Using kernels in a Docker container](docs/docker.md)
|
- [Using layers](docs/source/layers.md)
|
||||||
- [Kernel requirements](docs/kernel-requirements.md)
|
- [Locking kernel/layer versions](docs/source/locking.md)
|
||||||
- [Frequently Asked Questions](docs/faq.md)
|
- [Environment variables](docs/source/env.md)
|
||||||
|
- [Kernel requirements](docs/source/kernel-requirements.md)
|
||||||
|
- [Frequently Asked Questions](docs/source/faq.md)
|
||||||
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
|||||||
@ -1,8 +0,0 @@
|
|||||||
# Using kernels in a Docker container
|
|
||||||
|
|
||||||
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
|
||||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
|
||||||
```
|
|
||||||
30
docs/source/_toctree.yml
Normal file
30
docs/source/_toctree.yml
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
- sections:
|
||||||
|
- local: index
|
||||||
|
title: Introduction
|
||||||
|
- local: installation
|
||||||
|
title: Installation
|
||||||
|
title: Getting started
|
||||||
|
- sections:
|
||||||
|
- local: basic-usage
|
||||||
|
title: Basic Usage
|
||||||
|
- local: layers
|
||||||
|
title: Using Layers
|
||||||
|
- local: locking
|
||||||
|
title: Locking Kernel Versions
|
||||||
|
- local: env
|
||||||
|
title: Environment Variables
|
||||||
|
- local: faq
|
||||||
|
title: FAQ
|
||||||
|
title: Usage Guide
|
||||||
|
- sections:
|
||||||
|
- local: api/kernels
|
||||||
|
title: Kernels
|
||||||
|
- local: api/layers
|
||||||
|
title: Layers
|
||||||
|
- local: cli
|
||||||
|
title: Kernels CLI
|
||||||
|
title: API Reference
|
||||||
|
- sections:
|
||||||
|
- local: kernel-requirements
|
||||||
|
title: Kernel Requirements
|
||||||
|
title: Developer Guide
|
||||||
21
docs/source/api/kernels.md
Normal file
21
docs/source/api/kernels.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# Kernels API Reference
|
||||||
|
|
||||||
|
## Main Functions
|
||||||
|
|
||||||
|
### get_kernel
|
||||||
|
|
||||||
|
[[autodoc]] kernels.get_kernel
|
||||||
|
|
||||||
|
### has_kernel
|
||||||
|
|
||||||
|
[[autodoc]] kernels.has_kernel
|
||||||
|
|
||||||
|
## Loading locked kernels
|
||||||
|
|
||||||
|
### load_kernel
|
||||||
|
|
||||||
|
[[autodoc]] kernels.load_kernel
|
||||||
|
|
||||||
|
### get_locked_kernel
|
||||||
|
|
||||||
|
[[autodoc]] kernels.get_locked_kernel
|
||||||
41
docs/source/api/layers.md
Normal file
41
docs/source/api/layers.md
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Layers API Reference
|
||||||
|
|
||||||
|
## Making layers kernel-aware
|
||||||
|
|
||||||
|
### use_kernel_forward_from_hub
|
||||||
|
|
||||||
|
[[autodoc]] kernels.use_kernel_forward_from_hub
|
||||||
|
|
||||||
|
### replace_kernel_forward_from_hub
|
||||||
|
|
||||||
|
[[autodoc]] kernels.replace_kernel_forward_from_hub
|
||||||
|
|
||||||
|
## Registering kernel mappings
|
||||||
|
|
||||||
|
### use_kernel_mapping
|
||||||
|
|
||||||
|
[[autodoc]] kernels.use_kernel_mapping
|
||||||
|
|
||||||
|
### register_kernel_mapping
|
||||||
|
|
||||||
|
[[autodoc]] kernels.register_kernel_mapping
|
||||||
|
|
||||||
|
## Kernelizing a model
|
||||||
|
|
||||||
|
### kernelize
|
||||||
|
|
||||||
|
[[autodoc]] kernels.kernelize
|
||||||
|
|
||||||
|
## Classes
|
||||||
|
|
||||||
|
### Device
|
||||||
|
|
||||||
|
[[autodoc]] kernels.Device
|
||||||
|
|
||||||
|
### Mode
|
||||||
|
|
||||||
|
[[autodoc]] kernels.Mode
|
||||||
|
|
||||||
|
### LayerRepository
|
||||||
|
|
||||||
|
[[autodoc]] kernels.LayerRepository
|
||||||
50
docs/source/basic-usage.md
Normal file
50
docs/source/basic-usage.md
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# Basic Usage
|
||||||
|
|
||||||
|
## Loading Kernels
|
||||||
|
|
||||||
|
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
# Download optimized kernels from the Hugging Face hub
|
||||||
|
activation = get_kernel("kernels-community/activation")
|
||||||
|
|
||||||
|
# Create a random tensor
|
||||||
|
x = torch.randn((10, 10), dtype=torch.float16, device="cuda")
|
||||||
|
|
||||||
|
# Run the kernel
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
activation.gelu_fast(y, x)
|
||||||
|
|
||||||
|
print(y)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using version bounds
|
||||||
|
|
||||||
|
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
||||||
|
You can specify which version to download using Python version specifiers:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0")
|
||||||
|
```
|
||||||
|
|
||||||
|
This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It
|
||||||
|
is strongly recommended to specify a version bound, since a kernel author
|
||||||
|
might push incompatible changes to the `main` branch.
|
||||||
|
|
||||||
|
## Checking Kernel Availability
|
||||||
|
|
||||||
|
You can check if a specific kernel is available for your environment:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import has_kernel
|
||||||
|
|
||||||
|
# Check if kernel is available for current environment
|
||||||
|
is_available = has_kernel("kernels-community/activation")
|
||||||
|
print(f"Kernel available: {is_available}")
|
||||||
|
```
|
||||||
41
docs/source/cli.md
Normal file
41
docs/source/cli.md
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# Kernels CLI Reference
|
||||||
|
|
||||||
|
## Main Functions
|
||||||
|
|
||||||
|
### kernels to-wheel
|
||||||
|
|
||||||
|
We strongly recommend downloading kernels from the Hub using the `kernels`
|
||||||
|
package, since this comes with large [benefits](index.md) over using Python
|
||||||
|
wheels. That said, some projects may require deployment of kernels as
|
||||||
|
wheels. The `kernels` utility provides a simple solution to this. You can
|
||||||
|
convert any Hub kernel into a set of wheels with the `to-wheel` command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ kernels to-wheel drbh/img2grey 1.1.2
|
||||||
|
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu124cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch27cu128cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu126cxx98-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||||
|
☸ img2grey-1.1.2+torch27cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu126cxx11-cp39-abi3-manylinux_2_28_aarch64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu118cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu124cxx98-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch26cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
☸ img2grey-1.1.2+torch27cu118cxx11-cp39-abi3-manylinux_2_28_x86_64.whl
|
||||||
|
```
|
||||||
|
|
||||||
|
### kernels upload
|
||||||
|
|
||||||
|
Use `kernels upload <dir_containing_build> --repo_id="hub-username/kernel"` to upload
|
||||||
|
your kernel builds to the Hub.
|
||||||
|
|
||||||
|
**Notes**:
|
||||||
|
|
||||||
|
- This will take care of creating a repository on the Hub with the `repo_id` provided.
|
||||||
|
- If a repo with the `repo_id` already exists and if it contains a `build` with the build variant
|
||||||
|
being uploaded, it will attempt to delete the files existing under it.
|
||||||
|
- Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub.
|
||||||
|
|
||||||
@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
## Why is the kernelization step needed?
|
## Why is the kernelization step needed?
|
||||||
|
|
||||||
In earlier versions of `kernels`, a layer's `forward` was replaced by
|
In earlier versions of `kernels`, a layer's `forward` method was replaced
|
||||||
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
|
by `use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`.
|
||||||
new `forward` would dispatch to a kernel based on the device type,
|
The new `forward` would dispatch to a kernel based on the device type,
|
||||||
whether a model was training, etc. However, this approach was
|
whether a model was training, etc. However, this approach was
|
||||||
fundamentally incompatible with `torch.compile` since it relied
|
fundamentally incompatible with `torch.compile` since it relied
|
||||||
on data-dependent branching.
|
on data-dependent branching.
|
||||||
20
docs/source/index.md
Normal file
20
docs/source/index.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# Kernels
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="https://github.com/user-attachments/assets/64a652f3-0cd3-4829-b3c1-df13f7933569" width="450" height="450" alt="kernel-builder logo">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
The Kernel Hub allows Python libraries and applications to load compute
|
||||||
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||||
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||||
|
packages in that they are made to be:
|
||||||
|
|
||||||
|
- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`.
|
||||||
|
- **Unique**: multiple versions of the same kernel can be loaded in the
|
||||||
|
same Python process.
|
||||||
|
- **Compatible**: kernels must support all recent versions of Python and
|
||||||
|
the different PyTorch build configurations (various CUDA versions
|
||||||
|
and C++ ABIs). Furthermore, older C library versions must be supported.
|
||||||
|
|
||||||
|
You can [search for kernels](https://huggingface.co/models?other=kernel) on
|
||||||
|
the Hub.
|
||||||
16
docs/source/installation.md
Normal file
16
docs/source/installation.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# Installation
|
||||||
|
|
||||||
|
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install kernels
|
||||||
|
```
|
||||||
|
|
||||||
|
# Using kernels in a Docker container
|
||||||
|
|
||||||
|
Build and run the reference `examples/basic.py` in a Docker container with the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||||
|
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||||
|
```
|
||||||
@ -34,6 +34,8 @@ Kernels are versioned on the Hub using Git tags. Version tags must be of
|
|||||||
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
the form `v<major>.<minor>.<patch>`. Versions are used by [locking](./locking.md)
|
||||||
to resolve the version constraints.
|
to resolve the version constraints.
|
||||||
|
|
||||||
|
We recommend using [semver](https://semver.org/) to version kernels.
|
||||||
|
|
||||||
## Native Python module
|
## Native Python module
|
||||||
|
|
||||||
Kernels will typically contain a native Python module with precompiled
|
Kernels will typically contain a native Python module with precompiled
|
||||||
@ -50,13 +52,12 @@ have dynamic library dependencies outside:
|
|||||||
for compatibility with Python 3.9 and later.
|
for compatibility with Python 3.9 and later.
|
||||||
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
- Compatible with [`manylinux_2_28`](https://github.com/pypa/manylinux?tab=readme-ov-file#manylinux_2_28-almalinux-8-based).
|
||||||
This means that the extension **must not** use symbols versions higher than:
|
This means that the extension **must not** use symbols versions higher than:
|
||||||
|
|
||||||
- GLIBC 2.28
|
- GLIBC 2.28
|
||||||
- GLIBCXX 3.4.24
|
- GLIBCXX 3.4.24
|
||||||
- CXXABI 1.3.11
|
- CXXABI 1.3.11
|
||||||
- GCC 7.0.0
|
- GCC 7.0.0
|
||||||
|
|
||||||
These requirement can be checked with the ABI checker (see below).
|
These requirements can be checked with the ABI checker (see below).
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
@ -5,7 +5,7 @@ the Hub can replace the `forward` method of an existing layer for a certain
|
|||||||
device type. This makes it possible to provide more performant kernels for
|
device type. This makes it possible to provide more performant kernels for
|
||||||
existing layers.
|
existing layers.
|
||||||
|
|
||||||
See [Kernel requirements](kernel-requirements.md) for more information the
|
See [Kernel requirements](kernel-requirements.md) for more information on the
|
||||||
requirements of Hub layers.
|
requirements of Hub layers.
|
||||||
|
|
||||||
## Making a layer extensible with kernels from the hub
|
## Making a layer extensible with kernels from the hub
|
||||||
@ -84,12 +84,6 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
|
|||||||
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
```
|
```
|
||||||
|
|
||||||
When the `mode` argument is not specified,
|
|
||||||
`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
|
|
||||||
aligns most closely with pure PyTorch layers which also support training
|
|
||||||
and `torch.compile`. However, to select the most performant kernels, it
|
|
||||||
is often good to make the mode specific as possible.
|
|
||||||
|
|
||||||
### Kernel device
|
### Kernel device
|
||||||
|
|
||||||
Kernels can be registered per device type. For instance, separate `cuda` and
|
Kernels can be registered per device type. For instance, separate `cuda` and
|
||||||
@ -107,7 +101,7 @@ model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
|
|||||||
|
|
||||||
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
|
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
|
||||||
kernel does not support backward passes or `torch.compile` respectively,
|
kernel does not support backward passes or `torch.compile` respectively,
|
||||||
`kernenize` will fall back to the original, non-kernelized, layer. You
|
`kernelize` will fall back to the original, non-kernelized, layer. You
|
||||||
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
can let `kernelize` raise an exception instead by using `use_fallback=False`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@ -117,7 +111,7 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=
|
|||||||
|
|
||||||
This can be useful if you want to guarantee that Hub kernels are used.
|
This can be useful if you want to guarantee that Hub kernels are used.
|
||||||
|
|
||||||
### Inspecting kernels which kernels are used
|
### Inspecting which kernels are used
|
||||||
|
|
||||||
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
The kernels that are used are logged at the `INFO` level by `kernelize`.
|
||||||
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
See the [Python logging](https://docs.python.org/3/library/logging.html)
|
||||||
@ -135,6 +129,10 @@ kernel_layer_mapping = {
|
|||||||
"cuda": LayerRepository(
|
"cuda": LayerRepository(
|
||||||
repo_id="kernels-community/activation",
|
repo_id="kernels-community/activation",
|
||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
|
"rocm": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -153,12 +151,39 @@ used with the `use_kernel_mapping` context manager:
|
|||||||
```python
|
```python
|
||||||
with use_kernel_mapping(kernel_layer_mapping):
|
with use_kernel_mapping(kernel_layer_mapping):
|
||||||
# Use the layer for which the mapping is applied.
|
# Use the layer for which the mapping is applied.
|
||||||
model = kernelize(model)
|
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
```
|
```
|
||||||
|
|
||||||
This ensures that the mapping is not active anymore outside the
|
This ensures that the mapping is not active anymore outside the
|
||||||
`with`-scope.
|
`with`-scope.
|
||||||
|
|
||||||
|
### Using version bounds
|
||||||
|
|
||||||
|
Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
|
||||||
|
You can specify which version of the kernel to download using Python version
|
||||||
|
specifiers:
|
||||||
|
|
||||||
|
```python
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
version=">=0.0.4,<0.1.0",
|
||||||
|
),
|
||||||
|
"rocm": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
version=">=0.0.4,<0.1.0",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This will get the layer from latest kernel tagged `v0.0.z` where `z` is at
|
||||||
|
least 4. It is strongly recommended to specify a version bound, since a
|
||||||
|
kernel author might push incompatible changes to the `main` branch.
|
||||||
|
|
||||||
### Registering kernels for specific modes
|
### Registering kernels for specific modes
|
||||||
|
|
||||||
You might want to register two different kernels for a particular layer,
|
You might want to register two different kernels for a particular layer,
|
||||||
@ -261,7 +286,6 @@ Capabilities behave as follows:
|
|||||||
an existing kernel, the new kernel will replace the old kernel.
|
an existing kernel, the new kernel will replace the old kernel.
|
||||||
- When there are multiple kernels that support a capability, the kernel
|
- When there are multiple kernels that support a capability, the kernel
|
||||||
with the smaller capability interval will be used. E.g. given:
|
with the smaller capability interval will be used. E.g. given:
|
||||||
|
|
||||||
- `KernelA` with `min_capability=80` and `max_capability=89`;
|
- `KernelA` with `min_capability=80` and `max_capability=89`;
|
||||||
- `KernelB` with `min_capability=75` and `max_capability=89`;
|
- `KernelB` with `min_capability=75` and `max_capability=89`;
|
||||||
- `kernelize` runs on a system with capability 8.6.
|
- `kernelize` runs on a system with capability 8.6.
|
||||||
@ -270,3 +294,30 @@ Capabilities behave as follows:
|
|||||||
than 75..89. The motivation is that kernels with smaller ranges
|
than 75..89. The motivation is that kernels with smaller ranges
|
||||||
tend to be more optimized for a specific set of GPUs. **This behavior
|
tend to be more optimized for a specific set of GPUs. **This behavior
|
||||||
might still change in the future.**
|
might still change in the future.**
|
||||||
|
|
||||||
|
### Registering kernels for specific ROCm capabilities
|
||||||
|
|
||||||
|
Registering kernels for the ROCm architecture follows the exact same
|
||||||
|
pattern as CUDA kernels, using `min_capability` and `max_capability` to restrict
|
||||||
|
a kernel to a range of ROCm capabilities.
|
||||||
|
|
||||||
|
### Loading from a local repository for testing
|
||||||
|
|
||||||
|
The `LocalLayerRepository` class is provided to load a repository from
|
||||||
|
a local directory. For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"SiluAndMul": {
|
||||||
|
"cuda": LocalLayerRepository(
|
||||||
|
repo_path="/home/daniel/kernels/activation",
|
||||||
|
package_name="activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
inherit_mapping=False,
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
```
|
||||||
18
flake.lock
generated
18
flake.lock
generated
@ -58,11 +58,11 @@
|
|||||||
"nixpkgs": "nixpkgs"
|
"nixpkgs": "nixpkgs"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1750775451,
|
"lastModified": 1754038838,
|
||||||
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
|
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
|
||||||
"owner": "huggingface",
|
"owner": "huggingface",
|
||||||
"repo": "hf-nix",
|
"repo": "hf-nix",
|
||||||
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
|
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -73,17 +73,17 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1747820358,
|
"lastModified": 1752785354,
|
||||||
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
||||||
"owner": "danieldk",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
"owner": "danieldk",
|
"owner": "nixos",
|
||||||
"ref": "cudatoolkit-12.9-kernel-builder",
|
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
|
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@ -26,6 +26,10 @@
|
|||||||
formatter = pkgs.nixfmt-tree;
|
formatter = pkgs.nixfmt-tree;
|
||||||
devShells = with pkgs; rec {
|
devShells = with pkgs; rec {
|
||||||
default = mkShell {
|
default = mkShell {
|
||||||
|
nativeBuildInputs = [
|
||||||
|
# For hf-doc-builder.
|
||||||
|
nodejs
|
||||||
|
];
|
||||||
buildInputs =
|
buildInputs =
|
||||||
[
|
[
|
||||||
black
|
black
|
||||||
@ -36,6 +40,7 @@
|
|||||||
++ (with python3.pkgs; [
|
++ (with python3.pkgs; [
|
||||||
docutils
|
docutils
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
|
mktestdocs
|
||||||
pytest
|
pytest
|
||||||
pytest-benchmark
|
pytest-benchmark
|
||||||
pyyaml
|
pyyaml
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "kernels"
|
name = "kernels"
|
||||||
version = "0.8.1.dev0"
|
version = "0.10.1.dev0"
|
||||||
description = "Download compute kernels"
|
description = "Download compute kernels"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||||
@ -24,16 +24,20 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"mypy >= 1.15.0",
|
"mktestdocs>=0.2.5",
|
||||||
"pytest >=8",
|
"mypy>=1.15.0",
|
||||||
|
"pytest>=8",
|
||||||
# Whatever version is compatible with pytest.
|
# Whatever version is compatible with pytest.
|
||||||
"pytest-benchmark",
|
"pytest-benchmark",
|
||||||
"torch >=2.5",
|
"torch>=2.5",
|
||||||
"types-pyyaml"
|
"types-pyyaml"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
torch = ["torch"]
|
torch = ["torch"]
|
||||||
|
docs = [
|
||||||
|
"hf-doc-builder",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
kernels = "kernels.cli:main"
|
kernels = "kernels.cli:main"
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
markers =
|
markers =
|
||||||
|
cuda_only: marks tests that should only hosts with CUDA GPUs
|
||||||
|
rocm_only: marks tests that should only run on hosts with ROCm GPUs
|
||||||
darwin_only: marks tests that should only run on macOS
|
darwin_only: marks tests that should only run on macOS
|
||||||
linux_only: marks tests that should only run on Linux
|
xpu_only: marks tests that should only run on hosts with Intel XPUs
|
||||||
|
token: enable tests that require a write token
|
||||||
|
|||||||
@ -1,7 +1,13 @@
|
|||||||
|
import importlib.metadata
|
||||||
|
|
||||||
|
__version__ = importlib.metadata.version("kernels")
|
||||||
|
|
||||||
from kernels.layer import (
|
from kernels.layer import (
|
||||||
CUDAProperties,
|
CUDAProperties,
|
||||||
Device,
|
Device,
|
||||||
LayerRepository,
|
LayerRepository,
|
||||||
|
LocalLayerRepository,
|
||||||
|
LockedLayerRepository,
|
||||||
Mode,
|
Mode,
|
||||||
kernelize,
|
kernelize,
|
||||||
register_kernel_mapping,
|
register_kernel_mapping,
|
||||||
@ -19,9 +25,12 @@ from kernels.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"__version__",
|
||||||
"CUDAProperties",
|
"CUDAProperties",
|
||||||
"Device",
|
"Device",
|
||||||
"LayerRepository",
|
"LayerRepository",
|
||||||
|
"LocalLayerRepository",
|
||||||
|
"LockedLayerRepository",
|
||||||
"Mode",
|
"Mode",
|
||||||
"get_kernel",
|
"get_kernel",
|
||||||
"get_local_kernel",
|
"get_local_kernel",
|
||||||
|
|||||||
@ -4,6 +4,8 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import create_repo, upload_folder
|
||||||
|
|
||||||
from kernels.compat import tomllib
|
from kernels.compat import tomllib
|
||||||
from kernels.lockfile import KernelLock, get_kernel_locks
|
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||||
from kernels.utils import install_kernel, install_kernel_all_variants
|
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||||
@ -31,6 +33,24 @@ def main():
|
|||||||
)
|
)
|
||||||
download_parser.set_defaults(func=download_kernels)
|
download_parser.set_defaults(func=download_kernels)
|
||||||
|
|
||||||
|
upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub")
|
||||||
|
upload_parser.add_argument(
|
||||||
|
"kernel_dir",
|
||||||
|
type=Path,
|
||||||
|
help="Directory of the kernel build",
|
||||||
|
)
|
||||||
|
upload_parser.add_argument(
|
||||||
|
"--repo_id",
|
||||||
|
type=str,
|
||||||
|
help="Repository ID to use to upload to the Hugging Face Hub",
|
||||||
|
)
|
||||||
|
upload_parser.add_argument(
|
||||||
|
"--private",
|
||||||
|
action="store_true",
|
||||||
|
help="If the repository should be private.",
|
||||||
|
)
|
||||||
|
upload_parser.set_defaults(func=upload_kernels)
|
||||||
|
|
||||||
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
||||||
lock_parser.add_argument(
|
lock_parser.add_argument(
|
||||||
"project_dir",
|
"project_dir",
|
||||||
@ -153,6 +173,33 @@ def lock_kernels(args):
|
|||||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_kernels(args):
|
||||||
|
kernel_dir = Path(args.kernel_dir).resolve()
|
||||||
|
build_dir = kernel_dir / "build"
|
||||||
|
if not kernel_dir.is_dir():
|
||||||
|
raise ValueError(f"{kernel_dir} is not a directory")
|
||||||
|
if not build_dir.is_dir():
|
||||||
|
raise ValueError("Couldn't find `build` directory inside `kernel_dir`")
|
||||||
|
|
||||||
|
repo_id = create_repo(
|
||||||
|
repo_id=args.repo_id, private=args.private, exist_ok=True
|
||||||
|
).repo_id
|
||||||
|
|
||||||
|
delete_patterns: set[str] = set()
|
||||||
|
for build_variant in build_dir.iterdir():
|
||||||
|
if build_variant.is_dir():
|
||||||
|
delete_patterns.add(f"{build_variant.name}/**")
|
||||||
|
|
||||||
|
upload_folder(
|
||||||
|
repo_id=repo_id,
|
||||||
|
folder_path=build_dir,
|
||||||
|
path_in_repo="build",
|
||||||
|
delete_patterns=list(delete_patterns),
|
||||||
|
commit_message="Build uploaded using `kernels`.",
|
||||||
|
)
|
||||||
|
print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.")
|
||||||
|
|
||||||
|
|
||||||
class _JSONEncoder(json.JSONEncoder):
|
class _JSONEncoder(json.JSONEncoder):
|
||||||
def default(self, o):
|
def default(self, o):
|
||||||
if dataclasses.is_dataclass(o):
|
if dataclasses.is_dataclass(o):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -46,8 +46,9 @@ def build_variant() -> str:
|
|||||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
compute_framework = "metal"
|
compute_framework = "metal"
|
||||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
elif torch.version.xpu is not None:
|
||||||
compute_framework = "xpu"
|
version = torch.version.xpu
|
||||||
|
compute_framework = f"xpu{version[0:4]}{version[5:6]}"
|
||||||
else:
|
else:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
||||||
@ -98,7 +99,20 @@ def install_kernel(
|
|||||||
"""
|
"""
|
||||||
Download a kernel for the current environment to the cache.
|
Download a kernel for the current environment to the cache.
|
||||||
|
|
||||||
The output path is validated againt `hash` when set.
|
The output path is validated against the hashes in `variant_locks` when provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
The Hub repository containing the kernel.
|
||||||
|
revision (`str`):
|
||||||
|
The specific revision (branch, tag, or commit) to download.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only use local files and not download from the Hub.
|
||||||
|
variant_locks (`Dict[str, VariantLock]`, *optional*):
|
||||||
|
Optional dictionary of variant locks for validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
|
||||||
"""
|
"""
|
||||||
package_name = package_name_from_repo_id(repo_id)
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
variant = build_variant()
|
variant = build_variant()
|
||||||
@ -190,23 +204,31 @@ def get_kernel(
|
|||||||
) -> ModuleType:
|
) -> ModuleType:
|
||||||
"""
|
"""
|
||||||
Load a kernel from the kernel hub.
|
Load a kernel from the kernel hub.
|
||||||
This function downloads a kernel to the local Hugging Face Hub cache
|
|
||||||
directory (if it was not downloaded before) and then loads the kernel.
|
This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before)
|
||||||
|
and then loads the kernel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id (`str`): The Hub repository containing the kernel.
|
repo_id (`str`):
|
||||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
The Hub repository containing the kernel.
|
||||||
revision (branch, tag, or commit) to download.
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
Cannot be used together with `version`.
|
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
|
||||||
version (`str`, *optional*): The kernel version to download. This
|
version (`str`, *optional*):
|
||||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||||
Cannot be used together with `revision`.
|
Cannot be used together with `revision`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`ModuleType`: The imported kernel module.
|
`ModuleType`: The imported kernel module.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
|
import torch
|
||||||
from kernels import get_kernel
|
from kernels import get_kernel
|
||||||
kernel = get_kernel("username/my-kernel")
|
|
||||||
result = kernel.kernel_function(input_data)
|
activation = get_kernel("kernels-community/activation")
|
||||||
|
x = torch.randn(10, 20, device="cuda")
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
result = activation.silu_and_mul(out, x)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
revision = select_revision_or_version(repo_id, revision, version)
|
revision = select_revision_or_version(repo_id, revision, version)
|
||||||
@ -217,28 +239,53 @@ def get_kernel(
|
|||||||
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||||
"""
|
"""
|
||||||
Import a kernel from a local kernel repository path.
|
Import a kernel from a local kernel repository path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_path (`Path`):
|
||||||
|
The local path to the kernel repository.
|
||||||
|
package_name (`str`):
|
||||||
|
The name of the package to import from the repository.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`ModuleType`: The imported kernel module.
|
||||||
"""
|
"""
|
||||||
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
|
variant = build_variant()
|
||||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
universal_variant = universal_build_variant()
|
||||||
|
|
||||||
|
# Presume we were given the top level path of the kernel repository.
|
||||||
|
for base_path in [repo_path, repo_path / "build"]:
|
||||||
|
# Prefer the universal variant if it exists.
|
||||||
|
for v in [universal_variant, variant]:
|
||||||
|
package_path = base_path / v / package_name / "__init__.py"
|
||||||
|
if package_path.exists():
|
||||||
|
return import_from_path(package_name, package_path)
|
||||||
|
|
||||||
|
# If we didn't find the package in the repo we may have a explicit
|
||||||
|
# package path.
|
||||||
|
package_path = repo_path / package_name / "__init__.py"
|
||||||
|
if package_path.exists():
|
||||||
|
return import_from_path(package_name, package_path)
|
||||||
|
|
||||||
|
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
|
||||||
|
|
||||||
|
|
||||||
def has_kernel(
|
def has_kernel(
|
||||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether a kernel build exists for the current environment
|
Check whether a kernel build exists for the current environment (Torch version and compute framework).
|
||||||
(Torch version and compute framework).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id (`str`): The Hub repository containing the kernel.
|
repo_id (`str`):
|
||||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
The Hub repository containing the kernel.
|
||||||
revision (branch, tag, or commit) to download.
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
Cannot be used together with `version`.
|
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.
|
||||||
version (`str`, *optional*): The kernel version to download. This
|
version (`str`, *optional*):
|
||||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||||
Cannot be used together with `revision`.
|
Cannot be used together with `revision`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`bool`: `true` if a kernel is avaialble for the current environment.
|
`bool`: `True` if a kernel is available for the current environment.
|
||||||
"""
|
"""
|
||||||
revision = select_revision_or_version(repo_id, revision, version)
|
revision = select_revision_or_version(repo_id, revision, version)
|
||||||
|
|
||||||
@ -264,8 +311,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
|||||||
"""
|
"""
|
||||||
Get a pre-downloaded, locked kernel.
|
Get a pre-downloaded, locked kernel.
|
||||||
|
|
||||||
If `lockfile` is not specified, the lockfile will be loaded from the
|
If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata.
|
||||||
caller's package metadata.
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
The Hub repository containing the kernel.
|
||||||
|
lockfile (`Path`, *optional*):
|
||||||
|
Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`ModuleType`: The imported kernel module.
|
||||||
"""
|
"""
|
||||||
if lockfile is None:
|
if lockfile is None:
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
@ -310,7 +365,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
|
|||||||
|
|
||||||
|
|
||||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||||
"""Get a kernel using a lock file."""
|
"""
|
||||||
|
Get a kernel using a lock file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (`str`):
|
||||||
|
The Hub repository containing the kernel.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only use local files and not download from the Hub.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`ModuleType`: The imported kernel module.
|
||||||
|
"""
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
|
|
||||||
if locked_sha is None:
|
if locked_sha is None:
|
||||||
|
|||||||
@ -1,10 +1,41 @@
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
has_cuda = (
|
||||||
|
hasattr(torch.version, "cuda")
|
||||||
|
and torch.version.cuda is not None
|
||||||
|
and torch.cuda.device_count() > 0
|
||||||
|
)
|
||||||
|
has_rocm = (
|
||||||
|
hasattr(torch.version, "hip")
|
||||||
|
and torch.version.hip is not None
|
||||||
|
and torch.cuda.device_count() > 0
|
||||||
|
)
|
||||||
|
has_xpu = (
|
||||||
|
hasattr(torch.version, "xpu")
|
||||||
|
and torch.version.xpu is not None
|
||||||
|
and torch.xpu.device_count() > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption(
|
||||||
|
"--token",
|
||||||
|
action="store_true",
|
||||||
|
help="run tests that require a token with write permissions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_setup(item):
|
def pytest_runtest_setup(item):
|
||||||
if "linux_only" in item.keywords and not sys.platform.startswith("linux"):
|
if "cuda_only" in item.keywords and not has_cuda:
|
||||||
pytest.skip("skipping Linux-only test on non-Linux platform")
|
pytest.skip("skipping CUDA-only test on host without CUDA")
|
||||||
|
if "rocm_only" in item.keywords and not has_rocm:
|
||||||
|
pytest.skip("skipping ROCm-only test on host without ROCm")
|
||||||
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
|
||||||
pytest.skip("skipping macOS-only test on non-macOS platform")
|
pytest.skip("skipping macOS-only test on non-macOS platform")
|
||||||
|
if "xpu_only" in item.keywords and not has_xpu:
|
||||||
|
pytest.skip("skipping XPU-only test on host without XPU")
|
||||||
|
if "token" in item.keywords and not item.config.getoption("--token"):
|
||||||
|
pytest.skip("need --token option to run this test")
|
||||||
|
|||||||
@ -10,10 +10,16 @@ def kernel():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def local_kernel():
|
def local_kernel_path():
|
||||||
package_name, path = install_kernel("kernels-community/activation", "main")
|
package_name, path = install_kernel("kernels-community/activation", "main")
|
||||||
# Path is the build variant path (build/torch-<...>), so the grandparent
|
# Path is the build variant path (build/torch-<...>), so the grandparent
|
||||||
# is the kernel repository path.
|
# is the kernel repository path.
|
||||||
|
return package_name, path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def local_kernel(local_kernel_path):
|
||||||
|
package_name, path = local_kernel_path
|
||||||
return get_local_kernel(path.parent.parent, package_name)
|
return get_local_kernel(path.parent.parent, package_name)
|
||||||
|
|
||||||
|
|
||||||
@ -34,7 +40,7 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_gelu_fast(kernel, device):
|
def test_gelu_fast(kernel, device):
|
||||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
@ -50,7 +56,7 @@ def test_gelu_fast(kernel, device):
|
|||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_local_kernel(local_kernel, device):
|
def test_local_kernel(local_kernel, device):
|
||||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
@ -66,6 +72,39 @@ def test_local_kernel(local_kernel, device):
|
|||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
|
def test_local_kernel_path_types(local_kernel_path, device):
|
||||||
|
package_name, path = local_kernel_path
|
||||||
|
|
||||||
|
# Top-level repo path
|
||||||
|
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071
|
||||||
|
kernel = get_local_kernel(path.parent.parent, package_name)
|
||||||
|
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
|
||||||
|
kernel.gelu_fast(y, x)
|
||||||
|
expected = torch.tensor(
|
||||||
|
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float16,
|
||||||
|
)
|
||||||
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
# Build directory path
|
||||||
|
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build
|
||||||
|
kernel = get_local_kernel(path.parent.parent / "build", package_name)
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
kernel.gelu_fast(y, x)
|
||||||
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
# Explicit package path
|
||||||
|
# ie: /home/ubuntu/.cache/huggingface/hub/models--kernels-community--activation/snapshots/2fafa6a3a38ccb57a1a98419047cf7816ecbc071/build/torch28-cxx11-cu128-x86_64-linux
|
||||||
|
kernel = get_local_kernel(path, package_name)
|
||||||
|
y = torch.empty_like(x)
|
||||||
|
kernel.gelu_fast(y, x)
|
||||||
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.darwin_only
|
@pytest.mark.darwin_only
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
|
||||||
def test_relu_metal(metal_kernel, dtype):
|
def test_relu_metal(metal_kernel, dtype):
|
||||||
@ -74,7 +113,7 @@ def test_relu_metal(metal_kernel, dtype):
|
|||||||
assert torch.allclose(y, torch.relu(x))
|
assert torch.allclose(y, torch.relu(x))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"kernel_exists",
|
"kernel_exists",
|
||||||
[
|
[
|
||||||
@ -110,7 +149,7 @@ def test_version():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_universal_kernel(universal_kernel):
|
def test_universal_kernel(universal_kernel):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||||
|
|||||||
@ -16,21 +16,21 @@ def device():
|
|||||||
return "cuda"
|
return "cuda"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_gelu_small(kernel, device, benchmark):
|
def test_gelu_small(kernel, device, benchmark):
|
||||||
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_gelu_medium(kernel, device, benchmark):
|
def test_gelu_medium(kernel, device, benchmark):
|
||||||
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
benchmark(kernel.gelu_fast, y, x)
|
benchmark(kernel.gelu_fast, y, x)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_gelu_large(kernel, device, benchmark):
|
def test_gelu_large(kernel, device, benchmark):
|
||||||
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
|
|||||||
49
tests/test_doctest.py
Normal file
49
tests/test_doctest.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import inspect
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from mktestdocs import check_docstring, get_codeblock_members
|
||||||
|
|
||||||
|
import kernels
|
||||||
|
|
||||||
|
|
||||||
|
def all_public_functions():
|
||||||
|
function_list = inspect.getmembers(kernels, inspect.isfunction)
|
||||||
|
return [func for _, func in function_list]
|
||||||
|
|
||||||
|
|
||||||
|
def all_public_classes():
|
||||||
|
class_list = inspect.getmembers(kernels, inspect.isclass)
|
||||||
|
return [cls for _, cls in class_list]
|
||||||
|
|
||||||
|
|
||||||
|
def all_public_class_members():
|
||||||
|
members = get_codeblock_members(*all_public_classes())
|
||||||
|
return members
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"func",
|
||||||
|
all_public_functions(),
|
||||||
|
ids=lambda d: d.__name__,
|
||||||
|
)
|
||||||
|
def test_func_docstring(func):
|
||||||
|
check_docstring(obj=func)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"cls",
|
||||||
|
all_public_classes(),
|
||||||
|
ids=lambda d: d.__name__,
|
||||||
|
)
|
||||||
|
def test_class_docstring(cls):
|
||||||
|
check_docstring(obj=cls)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"member", all_public_class_members(), ids=lambda d: d.__qualname__
|
||||||
|
)
|
||||||
|
def test_member_docstring(member):
|
||||||
|
check_docstring(member)
|
||||||
@ -27,7 +27,7 @@ def test_download_all_hash_validation():
|
|||||||
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_load_locked():
|
def test_load_locked():
|
||||||
project_dir = Path(__file__).parent / "kernel_locking"
|
project_dir = Path(__file__).parent / "kernel_locking"
|
||||||
# Also validates that hashing works correctly.
|
# Also validates that hashing works correctly.
|
||||||
|
|||||||
88
tests/test_kernel_upload.py
Normal file
88
tests/test_kernel_upload.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from huggingface_hub import model_info
|
||||||
|
|
||||||
|
from kernels.cli import upload_kernels
|
||||||
|
|
||||||
|
REPO_ID = "kernels-test/kernels-upload-test"
|
||||||
|
|
||||||
|
PY_CONTENT = """\
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Hello from torch-universal!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UploadArgs:
|
||||||
|
kernel_dir: None
|
||||||
|
repo_id: None
|
||||||
|
private: False
|
||||||
|
|
||||||
|
|
||||||
|
def next_filename(path: Path) -> Path:
|
||||||
|
"""
|
||||||
|
Given a path like foo_2050.py, return foo_2051.py.
|
||||||
|
"""
|
||||||
|
m = re.match(r"^(.*?)(\d+)(\.py)$", path.name)
|
||||||
|
if not m:
|
||||||
|
raise ValueError(
|
||||||
|
f"Filename {path.name!r} does not match pattern <prefix>_<number>.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
prefix, number, suffix = m.groups()
|
||||||
|
new_number = str(int(number) + 1).zfill(len(number))
|
||||||
|
return path.with_name(f"{prefix}{new_number}{suffix}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_filename_to_change(repo_filenames):
|
||||||
|
for f in repo_filenames:
|
||||||
|
if "foo" in f and f.endswith(".py"):
|
||||||
|
filename_to_change = os.path.basename(f)
|
||||||
|
break
|
||||||
|
assert filename_to_change
|
||||||
|
return filename_to_change
|
||||||
|
|
||||||
|
|
||||||
|
def get_filenames_from_a_repo(repo_id: str) -> List[str]:
|
||||||
|
try:
|
||||||
|
repo_info = model_info(repo_id=repo_id, files_metadata=True)
|
||||||
|
repo_siblings = repo_info.siblings
|
||||||
|
if repo_siblings is not None:
|
||||||
|
return [f.rfilename for f in repo_siblings]
|
||||||
|
else:
|
||||||
|
raise ValueError("No repo siblings found.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error connecting to the Hub: {e}.")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.token
|
||||||
|
def test_kernel_upload_deletes_as_expected():
|
||||||
|
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||||
|
filename_to_change = get_filename_to_change(repo_filenames)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
path = f"{tmpdir}/build/torch-universal/upload_test"
|
||||||
|
build_dir = Path(path)
|
||||||
|
build_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
changed_filename = next_filename(Path(filename_to_change))
|
||||||
|
script_path = build_dir / changed_filename
|
||||||
|
script_path.write_text(PY_CONTENT)
|
||||||
|
upload_kernels(UploadArgs(tmpdir, REPO_ID, False))
|
||||||
|
|
||||||
|
repo_filenames = get_filenames_from_a_repo(REPO_ID)
|
||||||
|
assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"
|
||||||
|
assert not any(
|
||||||
|
str(filename_to_change) in k for k in repo_filenames
|
||||||
|
), f"{repo_filenames=}"
|
||||||
@ -5,21 +5,24 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
from torch.testing import assert_close
|
||||||
|
|
||||||
from kernels import (
|
from kernels import (
|
||||||
|
CUDAProperties,
|
||||||
Device,
|
Device,
|
||||||
LayerRepository,
|
LayerRepository,
|
||||||
|
LocalLayerRepository,
|
||||||
Mode,
|
Mode,
|
||||||
kernelize,
|
kernelize,
|
||||||
register_kernel_mapping,
|
register_kernel_mapping,
|
||||||
use_kernel_forward_from_hub,
|
use_kernel_forward_from_hub,
|
||||||
|
use_kernel_mapping,
|
||||||
)
|
)
|
||||||
from kernels.layer import (
|
from kernels.layer import (
|
||||||
_KERNEL_MAPPING,
|
_KERNEL_MAPPING,
|
||||||
CUDAProperties,
|
|
||||||
_validate_layer,
|
_validate_layer,
|
||||||
use_kernel_mapping,
|
|
||||||
)
|
)
|
||||||
|
from kernels.utils import install_kernel
|
||||||
|
|
||||||
kernel_layer_mapping = {
|
kernel_layer_mapping = {
|
||||||
"SiluAndMul": {
|
"SiluAndMul": {
|
||||||
@ -32,7 +35,11 @@ kernel_layer_mapping = {
|
|||||||
"cuda": LayerRepository(
|
"cuda": LayerRepository(
|
||||||
repo_id="kernels-test/op-without-fake-test",
|
repo_id="kernels-test/op-without-fake-test",
|
||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
)
|
),
|
||||||
|
"rocm": LayerRepository(
|
||||||
|
repo_id="kernels-test/op-without-fake-test",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
),
|
||||||
},
|
},
|
||||||
"SiluAndMulStringDevice": {
|
"SiluAndMulStringDevice": {
|
||||||
"cuda": LayerRepository(
|
"cuda": LayerRepository(
|
||||||
@ -40,11 +47,37 @@ kernel_layer_mapping = {
|
|||||||
layer_name="SiluAndMul",
|
layer_name="SiluAndMul",
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
"LigerRMSNorm": {
|
||||||
|
"xpu": LayerRepository(
|
||||||
|
repo_id="kernels-community/liger_kernels",
|
||||||
|
layer_name="LigerRMSNorm", # Triton
|
||||||
|
)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
register_kernel_mapping(kernel_layer_mapping)
|
register_kernel_mapping(kernel_layer_mapping)
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
# Used to check that we called hub kernel.
|
||||||
|
self.n_calls = 0
|
||||||
|
self.weight = nn.Parameter(weight)
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
self.n_calls += 1
|
||||||
|
var = x.pow(2).mean(-1, keepdim=True)
|
||||||
|
x_norm = x * torch.rsqrt(var + self.variance_epsilon)
|
||||||
|
return x_norm * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
@use_kernel_forward_from_hub("LigerRMSNorm")
|
||||||
|
class RMSNormWithKernel(RMSNorm):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SiluAndMul(nn.Module):
|
class SiluAndMul(nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -84,6 +117,16 @@ class TorchLinearWithCounter(nn.Linear):
|
|||||||
return super().forward(input)
|
return super().forward(input)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def device():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
|
return "xpu"
|
||||||
|
|
||||||
|
pytest.skip("No CUDA or XPU")
|
||||||
|
|
||||||
|
|
||||||
def test_arg_kinds():
|
def test_arg_kinds():
|
||||||
@use_kernel_forward_from_hub("ArgKind")
|
@use_kernel_forward_from_hub("ArgKind")
|
||||||
class ArgKind(nn.Module):
|
class ArgKind(nn.Module):
|
||||||
@ -102,29 +145,99 @@ def test_arg_kinds():
|
|||||||
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
def test_hub_forward(cls):
|
||||||
def test_hub_forward(cls, device):
|
|
||||||
torch.random.manual_seed(0)
|
torch.random.manual_seed(0)
|
||||||
|
|
||||||
silu_and_mul = SiluAndMul()
|
silu_and_mul = SiluAndMul()
|
||||||
X = torch.randn((32, 64), device=device)
|
X = torch.randn((32, 64), device="cuda")
|
||||||
Y = silu_and_mul(X)
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
|
silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
|
||||||
Y_kernel = silu_and_mul_with_kernel(X)
|
Y_kernel = silu_and_mul_with_kernel(X)
|
||||||
|
|
||||||
torch.testing.assert_close(Y_kernel, Y)
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
|
|
||||||
assert silu_and_mul.n_calls == 1
|
assert silu_and_mul.n_calls == 1
|
||||||
if device == "cuda":
|
assert silu_and_mul_with_kernel.n_calls == 0
|
||||||
assert silu_and_mul_with_kernel.n_calls == 0
|
|
||||||
else:
|
|
||||||
assert silu_and_mul_with_kernel.n_calls == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.rocm_only
|
||||||
|
def test_hub_forward_rocm():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
silu_and_mul = SiluAndMul()
|
||||||
|
X = torch.randn((32, 64))
|
||||||
|
Y = silu_and_mul(X)
|
||||||
|
|
||||||
|
silu_and_mul_with_kernel = kernelize(
|
||||||
|
SiluAndMulNoCompileKernel(), device="rocm", mode=Mode.INFERENCE
|
||||||
|
)
|
||||||
|
Y_kernel = silu_and_mul_with_kernel(X)
|
||||||
|
|
||||||
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
|
|
||||||
|
assert silu_and_mul.n_calls == 1
|
||||||
|
# Should use kernel (n_calls == 0) if ROCm kernel is available, otherwise fallback (n_calls == 1)
|
||||||
|
# The exact behavior depends on whether the test kernel exists for ROCm
|
||||||
|
assert silu_and_mul_with_kernel.n_calls in [0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xpu_only
|
||||||
|
def test_hub_forward_xpu():
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
hidden_size = 1024
|
||||||
|
weight = torch.ones(hidden_size, device="xpu")
|
||||||
|
rms_norm = RMSNorm(weight).to("xpu")
|
||||||
|
X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
|
||||||
|
Y = rms_norm(X)
|
||||||
|
|
||||||
|
rms_norm_with_kernel = kernelize(
|
||||||
|
RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
|
||||||
|
)
|
||||||
|
Y_kernel = rms_norm_with_kernel(X)
|
||||||
|
|
||||||
|
torch.testing.assert_close(Y_kernel, Y)
|
||||||
|
|
||||||
|
assert rms_norm.n_calls == 1
|
||||||
|
assert rms_norm_with_kernel.n_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
|
||||||
|
reason="Skip on xpu devices",
|
||||||
|
)
|
||||||
|
def test_rocm_kernel_mapping():
|
||||||
|
"""Test that ROCm shorthand device mapping works correctly."""
|
||||||
|
kernel_layer_mapping = {
|
||||||
|
"SiluAndMul": {
|
||||||
|
"rocm": LayerRepository(
|
||||||
|
repo_id="kernels-community/activation",
|
||||||
|
layer_name="SiluAndMul",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test that the mapping is processed correctly
|
||||||
|
with use_kernel_mapping(kernel_layer_mapping, inherit_mapping=False):
|
||||||
|
mapping = _KERNEL_MAPPING.get()
|
||||||
|
|
||||||
|
# Verify the mapping exists
|
||||||
|
assert "SiluAndMul" in mapping
|
||||||
|
assert "rocm" in mapping["SiluAndMul"]
|
||||||
|
|
||||||
|
# Verify the repository is correctly stored
|
||||||
|
rocm_repos = mapping["SiluAndMul"]["rocm"]
|
||||||
|
assert rocm_repos is not None
|
||||||
|
assert (
|
||||||
|
rocm_repos.repos[Mode.FALLBACK]._repo_id == "kernels-community/activation"
|
||||||
|
)
|
||||||
|
assert rocm_repos.repos[Mode.FALLBACK].layer_name == "SiluAndMul"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
def test_capability():
|
def test_capability():
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
@ -183,7 +296,74 @@ def test_layer_fallback_works():
|
|||||||
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
def test_local_layer_repo(device):
|
||||||
|
# Fetch a kernel to the local cache.
|
||||||
|
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")
|
||||||
|
|
||||||
|
linear = TorchLinearWithCounter(32, 32).to(device)
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"Linear": {
|
||||||
|
device: LocalLayerRepository(
|
||||||
|
# install_kernel will give the fully-resolved path.
|
||||||
|
repo_path=path.parent.parent,
|
||||||
|
package_name=package_name,
|
||||||
|
layer_name="LinearBackward",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
inherit_mapping=False,
|
||||||
|
):
|
||||||
|
kernelize(linear, mode=Mode.INFERENCE)
|
||||||
|
|
||||||
|
X = torch.randn(10, 32, device=device)
|
||||||
|
linear(X)
|
||||||
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_stateful_layer(device):
|
||||||
|
@use_kernel_forward_from_hub("ReluWithHiddenSize")
|
||||||
|
class ReluWithHiddenSize(nn.Module):
|
||||||
|
hidden_size: int
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return F.relu(x)
|
||||||
|
|
||||||
|
model = ReluWithHiddenSize(hidden_size=64).to(device)
|
||||||
|
x = torch.randn((32, 64), device=device)
|
||||||
|
y_ref = model(x)
|
||||||
|
|
||||||
|
with use_kernel_mapping(
|
||||||
|
{
|
||||||
|
"ReluWithHiddenSize": {
|
||||||
|
"cuda": LayerRepository(
|
||||||
|
repo_id="kernels-test/state-test",
|
||||||
|
layer_name="StatefulReLU",
|
||||||
|
),
|
||||||
|
"xpu": LayerRepository(
|
||||||
|
repo_id="kernels-test/state-test",
|
||||||
|
layer_name="StatefulReLU",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
inherit_mapping=False,
|
||||||
|
):
|
||||||
|
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device=device)
|
||||||
|
|
||||||
|
y = model(x)
|
||||||
|
assert_close(y, y_ref)
|
||||||
|
|
||||||
|
model = torch.compile(model, fullgraph=True)
|
||||||
|
y = model(x)
|
||||||
|
assert_close(y, y_ref)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
def test_torch_compile_layer_without_fallback(cls, device):
|
def test_torch_compile_layer_without_fallback(cls, device):
|
||||||
@ -214,7 +394,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
|
|||||||
torch.testing.assert_close(Y_compiled, Y)
|
torch.testing.assert_close(Y_compiled, Y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||||
@pytest.mark.parametrize("device", ["cuda"])
|
@pytest.mark.parametrize("device", ["cuda"])
|
||||||
def test_torch_compile_layer_with_fallback(cls, device):
|
def test_torch_compile_layer_with_fallback(cls, device):
|
||||||
@ -237,12 +417,16 @@ def test_torch_compile_layer_with_fallback(cls, device):
|
|||||||
torch.testing.assert_close(Y_compiled, Y)
|
torch.testing.assert_close(Y_compiled, Y)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_mapping_contexts():
|
def test_mapping_contexts():
|
||||||
|
# Make sure we start from scratch.
|
||||||
|
register_kernel_mapping(kernel_layer_mapping, inherit_mapping=False)
|
||||||
|
|
||||||
assert set(_KERNEL_MAPPING.get().keys()) == {
|
assert set(_KERNEL_MAPPING.get().keys()) == {
|
||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
}
|
}
|
||||||
|
|
||||||
extra_mapping1 = {
|
extra_mapping1 = {
|
||||||
@ -260,6 +444,7 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,10 +463,13 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
|
||||||
|
.repos[Mode.FALLBACK]
|
||||||
|
._repo_id
|
||||||
== "kernels-community/non-existing"
|
== "kernels-community/non-existing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -289,10 +477,11 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
|
||||||
== "kernels-community/activation"
|
== "kernels-community/activation"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -301,7 +490,9 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"]
|
||||||
|
.repos[Mode.FALLBACK]
|
||||||
|
._repo_id
|
||||||
== "kernels-community/non-existing"
|
== "kernels-community/non-existing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -309,10 +500,11 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
"TestKernel",
|
"TestKernel",
|
||||||
}
|
}
|
||||||
assert (
|
assert (
|
||||||
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
|
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK]._repo_id
|
||||||
== "kernels-community/activation"
|
== "kernels-community/activation"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -320,6 +512,7 @@ def test_mapping_contexts():
|
|||||||
"SiluAndMul",
|
"SiluAndMul",
|
||||||
"SiluAndMulStringDevice",
|
"SiluAndMulStringDevice",
|
||||||
"SiluAndMulNoCompile",
|
"SiluAndMulNoCompile",
|
||||||
|
"LigerRMSNorm",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -329,29 +522,46 @@ def test_validate_kernel_layer():
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.foo = 42
|
self.foo = 42
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="not override"):
|
def stub_repo(layer):
|
||||||
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
|
return LayerRepository(
|
||||||
|
repo_id="kernels-test/nonexisting", layer_name=layer.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
TypeError,
|
||||||
|
match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
|
||||||
|
):
|
||||||
|
_validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))
|
||||||
|
|
||||||
class BadLayer2(nn.Module):
|
class BadLayer2(nn.Module):
|
||||||
foo: int = 42
|
foo: int = 42
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="not contain additional members"):
|
with pytest.raises(
|
||||||
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
|
TypeError,
|
||||||
|
match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
|
||||||
|
):
|
||||||
|
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))
|
||||||
|
|
||||||
class BadLayer3(nn.Module):
|
class BadLayer3(nn.Module):
|
||||||
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="different number of arguments"):
|
with pytest.raises(
|
||||||
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
|
TypeError,
|
||||||
|
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
|
||||||
|
):
|
||||||
|
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))
|
||||||
|
|
||||||
class BadLayer4(nn.Module):
|
class BadLayer4(nn.Module):
|
||||||
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="different kind of arguments"):
|
with pytest.raises(
|
||||||
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
|
TypeError,
|
||||||
|
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
|
||||||
|
):
|
||||||
|
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_invalid_mode_for_mapping_rejected():
|
def test_invalid_mode_for_mapping_rejected():
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
@ -371,7 +581,7 @@ def test_invalid_mode_for_mapping_rejected():
|
|||||||
kernelize(linear, mode=Mode.TRAINING)
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_kernel_modes():
|
def test_kernel_modes():
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
@ -400,11 +610,6 @@ def test_kernel_modes():
|
|||||||
linear(X)
|
linear(X)
|
||||||
assert linear.n_calls == 0
|
assert linear.n_calls == 0
|
||||||
|
|
||||||
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
|
||||||
kernelize(linear)
|
|
||||||
linear(X)
|
|
||||||
assert linear.n_calls == 0
|
|
||||||
|
|
||||||
# Case 2: register a kernel just for training. If no base kernel
|
# Case 2: register a kernel just for training. If no base kernel
|
||||||
# layer is registered, we fall back to the original layer.
|
# layer is registered, we fall back to the original layer.
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
@ -434,12 +639,6 @@ def test_kernel_modes():
|
|||||||
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
|
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
|
||||||
assert linear.n_calls == 1
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
|
||||||
kernelize(linear)
|
|
||||||
linear(X)
|
|
||||||
# TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
# Case 3: register a kernel just for training and one for fallback.
|
# Case 3: register a kernel just for training and one for fallback.
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
@ -461,23 +660,17 @@ def test_kernel_modes():
|
|||||||
X = torch.randn(10, 32, device="cuda")
|
X = torch.randn(10, 32, device="cuda")
|
||||||
linear(X)
|
linear(X)
|
||||||
# Falls back to TRAINING.
|
# Falls back to TRAINING.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
linear(X)
|
linear(X)
|
||||||
# Falls back to the TRAINING kernel.
|
# Falls back to the TRAINING kernel.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
linear(X)
|
linear(X)
|
||||||
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
|
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
|
||||||
kernelize(linear)
|
|
||||||
linear(X)
|
|
||||||
# TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
# Case 4: register a kernel with two preferences.
|
# Case 4: register a kernel with two preferences.
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
@ -497,25 +690,20 @@ def test_kernel_modes():
|
|||||||
X = torch.randn(10, 32, device="cuda")
|
X = torch.randn(10, 32, device="cuda")
|
||||||
linear(X)
|
linear(X)
|
||||||
# Falls back to the TRAINING | TORCH_COMPILE kernel.
|
# Falls back to the TRAINING | TORCH_COMPILE kernel.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING)
|
kernelize(linear, mode=Mode.TRAINING)
|
||||||
linear(X)
|
linear(X)
|
||||||
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
|
# TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
|
||||||
linear(X)
|
linear(X)
|
||||||
# Uses TRAINING | TORCH_COMPILE kernel.
|
# Uses TRAINING | TORCH_COMPILE kernel.
|
||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 1
|
||||||
|
|
||||||
kernelize(linear)
|
|
||||||
linear(X)
|
|
||||||
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
|
|
||||||
assert linear.n_calls == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_fallback_used_when_training():
|
def test_fallback_used_when_training():
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
|
|
||||||
@ -580,7 +768,7 @@ def test_invalid_mode_rejected():
|
|||||||
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
|
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_kernel_modes_inference():
|
def test_kernel_modes_inference():
|
||||||
"""Test inference-specific fallback scenarios."""
|
"""Test inference-specific fallback scenarios."""
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
@ -677,7 +865,7 @@ def test_kernel_modes_inference():
|
|||||||
assert linear.n_calls == 4
|
assert linear.n_calls == 4
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_kernel_modes_mixed():
|
def test_kernel_modes_mixed():
|
||||||
"""Test mixed training and inference kernel scenarios."""
|
"""Test mixed training and inference kernel scenarios."""
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
@ -767,7 +955,7 @@ def test_kernel_modes_mixed():
|
|||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.linux_only
|
@pytest.mark.cuda_only
|
||||||
def test_kernel_modes_cross_fallback():
|
def test_kernel_modes_cross_fallback():
|
||||||
"""Test cross-mode fallback scenarios from inference to training modes."""
|
"""Test cross-mode fallback scenarios from inference to training modes."""
|
||||||
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
linear = TorchLinearWithCounter(32, 32).to("cuda")
|
||||||
@ -861,7 +1049,7 @@ def test_kernel_modes_cross_fallback():
|
|||||||
assert linear.n_calls == 2
|
assert linear.n_calls == 2
|
||||||
|
|
||||||
|
|
||||||
def test_layer_versions():
|
def test_layer_versions(device):
|
||||||
@use_kernel_forward_from_hub("Version")
|
@use_kernel_forward_from_hub("Version")
|
||||||
class Version(nn.Module):
|
class Version(nn.Module):
|
||||||
def forward(self) -> str:
|
def forward(self) -> str:
|
||||||
@ -872,20 +1060,20 @@ def test_layer_versions():
|
|||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
||||||
assert version() == "0.2.0"
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version="<1.0.0",
|
version="<1.0.0",
|
||||||
@ -893,13 +1081,13 @@ def test_layer_versions():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
||||||
assert version() == "0.2.0"
|
assert version() == "0.2.0"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version="<0.2.0",
|
version="<0.2.0",
|
||||||
@ -907,13 +1095,13 @@ def test_layer_versions():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
||||||
assert version() == "0.1.1"
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version=">0.1.0,<0.2.0",
|
version=">0.1.0,<0.2.0",
|
||||||
@ -921,13 +1109,13 @@ def test_layer_versions():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
version = kernelize(version, device=device, mode=Mode.INFERENCE)
|
||||||
assert version() == "0.1.1"
|
assert version() == "0.1.1"
|
||||||
|
|
||||||
with use_kernel_mapping(
|
with use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
version=">0.2.0",
|
version=">0.2.0",
|
||||||
@ -936,13 +1124,13 @@ def test_layer_versions():
|
|||||||
}
|
}
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||||
kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
kernelize(version, device=device, mode=Mode.INFERENCE)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||||
use_kernel_mapping(
|
use_kernel_mapping(
|
||||||
{
|
{
|
||||||
"Version": {
|
"Version": {
|
||||||
Device(type="cuda"): LayerRepository(
|
Device(type=device): LayerRepository(
|
||||||
repo_id="kernels-test/versions",
|
repo_id="kernels-test/versions",
|
||||||
layer_name="Version",
|
layer_name="Version",
|
||||||
revision="v0.1.0",
|
revision="v0.1.0",
|
||||||
|
|||||||
Reference in New Issue
Block a user