Compare commits

...

13 Commits

Author SHA1 Message Date
1be0a76e44 Set version to 0.6.0 2025-06-04 12:00:03 +00:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
848c6db87b Add support for Metal builds (#89)
* Add support for Metal builds

* Add Metal test, gate tests by OS where necessary
2025-05-30 15:54:28 +02:00
fabb8c52d1 Add generate-readme subcommand for generating a README (#88)
* Add `generate-readme` subcommand for generating a README

This README includes all the top-level functions with docs (if
docstrings are available).

* CI: attempt README generation

* Add PyYAML dependencies

* Typing fixes
2025-05-21 15:43:53 +02:00
d66260dd83 kernels: add the to-wheel subcommand (#84)
* kernels: add the `to-wheel` subcommand

This subcommand accepts a kernel repo and version as arguments:

    kernels to-wheel kernels-community/activation 0.0.3

Wheels will then be generated for every build variant.

* CI: check kernel -> wheel conversion

* No typing for wheel.wheelfile
2025-05-08 17:30:06 +02:00
daac8078fc CI: fix some stubs (#83) 2025-05-07 14:43:57 +02:00
fcb9a80ce6 Set version to 0.5.0 (#82) 2025-05-06 11:45:26 +02:00
c25bb32e6e Add publishing workflow (#81) 2025-05-06 09:29:08 +00:00
2036892762 Allow layers to opt in to torch.compile (#79)
* Allow layers to opt in to `torch.compile`

This change allows a layer to set the `can_torch_compile` class
variable to indicate that the layer is compatible with `torch.compile`.
When enabled, the layer does not fall back to the original
implementation when `torch.compile` is used.

* Comment fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-06 09:36:33 +02:00
23 changed files with 1778 additions and 141 deletions

120
.github/workflows/publish.yml vendored Normal file
View File

@ -0,0 +1,120 @@
# Build the distribution once, publish to PyPI on tag pushes only,
# sign the artifacts with Sigstore and attach them to a GitHub release,
# and mirror every build to TestPyPI (skipping already-uploaded versions).
name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI

on: push

jobs:
  build:
    name: Build distribution 📦
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.9"
      - name: Install pypa/build
        run: >-
          python3 -m
          pip install
          build
          --user
      - name: Build a binary wheel and a source tarball
        run: python3 -m build
      - name: Store the distribution packages
        uses: actions/upload-artifact@v4
        with:
          name: python-package-distributions
          path: dist/

  publish-to-pypi:
    name: >-
      Publish Python 🐍 distribution 📦 to PyPI
    if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/kernels
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1

  github-release:
    name: >-
      Sign the Python 🐍 distribution 📦 with Sigstore
      and upload them to GitHub Release
    needs:
      - publish-to-pypi
    runs-on: ubuntu-latest
    permissions:
      contents: write # IMPORTANT: mandatory for making GitHub Releases
      id-token: write # IMPORTANT: mandatory for sigstore
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Sign the dists with Sigstore
        uses: sigstore/gh-action-sigstore-python@v3.0.0
        with:
          inputs: >-
            ./dist/*.tar.gz
            ./dist/*.whl
      - name: Create GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        run: >-
          gh release create
          "$GITHUB_REF_NAME"
          --repo "$GITHUB_REPOSITORY"
          --notes ""
      - name: Upload artifact signatures to GitHub Release
        env:
          GITHUB_TOKEN: ${{ github.token }}
        # Upload to GitHub Release using the `gh` CLI.
        # `dist/` contains the built packages, and the
        # sigstore-produced signatures and certificates.
        run: >-
          gh release upload
          "$GITHUB_REF_NAME" dist/**
          --repo "$GITHUB_REPOSITORY"

  publish-to-testpypi:
    name: Publish Python 🐍 distribution 📦 to TestPyPI
    needs:
      - build
    runs-on: ubuntu-latest
    environment:
      name: testpypi
      url: https://test.pypi.org/p/kernels
    permissions:
      id-token: write # IMPORTANT: mandatory for trusted publishing
    steps:
      - name: Download all the dists
        uses: actions/download-artifact@v4
        with:
          name: python-package-distributions
          path: dist/
      - name: Publish distribution 📦 to TestPyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/
          skip-existing: true # Only upload when the version is unique.

View File

@ -24,7 +24,7 @@ jobs:
max-parallel: 4
matrix:
python-version: ["3.10", "3.12"]
torch-version: ["2.5.1", "2.6.0"]
torch-version: ["2.6.0", "2.7.0"]
env:
UV_PYTHON_PREFERENCE: only-managed
@ -53,6 +53,18 @@ jobs:
- name: Run tests
run: uv run pytest tests
- name: Check kernel conversion
run: |
uv pip install wheel
uv run kernels to-wheel kernels-community/triton-layer-norm 0.0.1
uv pip install triton_layer_norm-0.0.1*.whl
uv run python -c "import triton_layer_norm"
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
- name: Import check without torch
run: |
uv pip uninstall torch

View File

@ -61,4 +61,5 @@ the Hub.
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

13
docs/faq.md Normal file
View File

@ -0,0 +1,13 @@
# FAQ
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.

View File

@ -109,9 +109,12 @@ requirements:
- The `forward` method has a signature that is compatible with the
`forward` method that it is extending.
The only exception to the _no class variables rule_ is addition of a
`has_backward` class variable. This variable is used to indicate whether
the layer has a backward pass implemented (`True` when absent).
There are two exceptions to the _no class variables rule_:
1. The `has_backward` variable can be used to indicate whether the layer has
a backward pass implemented (`True` when absent).
2. The `can_torch_compile` variable can be used to indicate whether the layer
supports `torch.compile` (`False` when absent).
This is an example of a pure layer:

View File

@ -23,33 +23,84 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by monkeypatching it using the `replace_kernel_forward_from_hub` function.
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model)
```
**Note:** the `kernelize` function modifies the model in-place, the model
itself is returned as a convenience.
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda")
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, pass the `needs_torch_compile` argument to ensure that
only kernels that support `torch.compile` will be loaded:
```python
model = MyModel(...)
model = kernelize(model, needs_torch_compile=True)
```
### Fallback forward
By default, `kernelize` falls back to the layer's original
`forward` if the registered kernel does not support `torch.compile`. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, needs_torch_compile=True, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
@ -61,7 +112,11 @@ kernel_layer_mapping = {
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
@ -72,7 +127,7 @@ used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
...
model = kernelize(model)
```
This ensures that the mapping is not active anymore outside the

56
flake.lock generated
View File

@ -51,18 +51,39 @@
"type": "github"
}
},
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1749025620,
"narHash": "sha256-V/r5KOp8FRC5n3MINDzTeS3pZz57SasFVzx12WQRQ8U=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "7ab84ffad440c530162f528a96fa062530a6c8e4",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "torch-cxx11",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
@ -70,11 +91,11 @@
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"tgi-nix",
"hf-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
]
}
},
"systems": {
@ -106,27 +127,6 @@
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",

View File

@ -1,7 +1,7 @@
{
inputs = {
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
hf-nix.url = "github:huggingface/hf-nix/torch-cxx11";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
@ -9,21 +9,21 @@
self,
nixpkgs,
flake-utils,
tgi-nix,
hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
inherit (hf-nix.lib) config;
overlays = [
tgi-nix.overlays.default
hf-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-rfc-style;
formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec {
default = mkShell {
buildInputs =
@ -34,10 +34,13 @@
ruff
]
++ (with python3.pkgs; [
docutils
huggingface-hub
pytest
pytest-benchmark
pyyaml
torch
types-pyyaml
venvShellHook
]);

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.4.4"
version = "0.6.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -14,6 +14,7 @@ requires-python = ">= 3.9"
dependencies = [
"huggingface_hub>=0.26.0,<1.0",
"packaging>=20.0",
"pyyaml>=6",
"tomli>=2.0; python_version<'3.11'",
]
@ -28,6 +29,7 @@ dev = [
# Whatever version is compatible with pytest.
"pytest-benchmark",
"torch >=2.5",
"types-pyyaml"
]
[project.optional-dependencies]

4
pytest.ini Normal file
View File

@ -0,0 +1,4 @@
# Custom markers used to gate OS-specific tests (registered here so
# pytest does not warn about unknown marks).
[pytest]
markers =
    darwin_only: marks tests that should only run on macOS
    linux_only: marks tests that should only run on Linux

View File

@ -1,6 +1,7 @@
from kernels.layer import (
Device,
LayerRepository,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
@ -26,4 +27,5 @@ __all__ = [
"replace_kernel_forward_from_hub",
"LayerRepository",
"Device",
"kernelize",
]

View File

@ -0,0 +1,751 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Vendored from https://github.com/huggingface/doc-builder/blob/main/src/doc_builder/convert_rst_to_mdx.py
import re

# Regexes for inline rst roles and backquote spans, used by
# `convert_rst_formatting` below.
# Re pattern to catch things inside ` ` in :obj:`thing`.
_re_obj = re.compile(r":obj:`([^`]+)`")
# Re pattern to catch things inside ` ` in :math:`thing`.
_re_math = re.compile(r":math:`([^`]+)`")
# Re pattern to catch things between single backquotes.
_re_single_backquotes = re.compile(r"(^|[^`])`([^`]+)`([^`]|$)")
# Re pattern to catch things between double backquotes.
_re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
# Re pattern to catch things inside ` ` in :func/class/meth:`thing`.
_re_func_class = re.compile(r":(?:func|class|meth):`([^`]+)`")
def convert_rst_formatting(text):
    """
    Convert rst syntax for formatting to markdown in a given text.

    Args:
        text (`str`): The rst text to convert.

    Returns:
        `str`: The text with rst roles (:obj:/:math:/:func:/:class:/:meth:)
        and single/double backquotes rewritten as markdown, and newlines
        inside inline code spans collapsed.
    """
    # Remove :class:, :func: and :meth: markers. To code-links and put double backquotes
    # (to not be caught by the italic conversion).
    text = _re_func_class.sub(r"[``\1``]", text)
    # Remove :obj: markers. What's after is in a single backquotes so we put in double backquotes
    # (to not be caught by the italic conversion).
    text = _re_obj.sub(r"``\1``", text)
    # Remove :math: markers.
    text = _re_math.sub(r"\\\\(\1\\\\)", text)
    # Convert content in single backquotes to italic.
    text = _re_single_backquotes.sub(r"\1*\2*\3", text)
    # Convert content in double backquotes to single backquotes.
    text = _re_double_backquotes.sub(r"\1`\2`\3", text)
    # Remove remaining ::
    text = re.sub(r"::\n", "", text)

    # Remove new lines inside blocks in backticks as they will be kept.
    # `in_code` tracks whether we are inside an unterminated inline code
    # span (odd number of backticks seen so far on the current logical line).
    lines = text.split("\n")
    in_code = False
    text = None
    for line in lines:
        if in_code:
            splits = line.split("`")
            # Still in code if this line contains backticks but an even
            # number of them (the span is not closed yet).
            in_code = len(splits) > 1 and len(splits) % 2 == 1
            if len(splits) == 1:
                # Some forgotten lone backtick
                text += "\n" + line
            else:
                # Fold the continuation line into the previous one.
                text += " " + line.lstrip()
        else:
            if text is not None:
                text += "\n" + line
            else:
                text = line
            splits = line.split("`")
            # An odd number of backticks on the line opens a code span.
            in_code = len(splits) % 2 == 0
    return text
# Regexes for rst hyperlink forms, used by `convert_rst_links` below.
# Re pattern to catch description and url in links of the form `description <url>`_.
_re_links = re.compile(r"`([^`]+\S)\s+</*([^/][^>`]*)>`_+")
# Re pattern to catch description and url in links of the form :prefix_link:`description <url>`_.
_re_prefix_links = re.compile(r":prefix_link:`([^`]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :doc:`reference`.
_re_simple_doc = re.compile(r":doc:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :doc:`description <reference>`.
_re_doc_with_description = re.compile(r":doc:`([^`<]+\S)\s+</*([^/][^>`]*)>`")
# Re pattern to catch reference in links of the form :ref:`reference`.
_re_simple_ref = re.compile(r":ref:`([^`<]*)`")
# Re pattern to catch description and reference in links of the form :ref:`description <reference>`.
_re_ref_with_description = re.compile(r":ref:`([^`<]+\S)\s+<([^>]*)>`")
def convert_rst_links(text, page_info):
    """
    Convert the rst links in text to markdown.

    Args:
        text (`str`): The text to convert.
        page_info (`dict`): Must contain `"package_name"`. May also contain
            `"version"` (default `"main"`), `"language"` (default `"en"`),
            `"page"` (current page, used as anchor prefix for :ref: links)
            and `"no_prefix"` (disable the /docs/... prefix).

    Returns:
        `str`: The text with :doc:/:ref:/:prefix_link: roles and rst
        hyperlinks rewritten as markdown links.

    Raises:
        ValueError: If `page_info` does not contain `"package_name"`.
    """
    if "package_name" not in page_info:
        raise ValueError("`page_info` must contain at least the package_name.")
    package_name = page_info["package_name"]
    version = page_info.get("version", "main")
    language = page_info.get("language", "en")
    no_prefix = page_info.get("no_prefix", False)
    prefix = "" if no_prefix else f"/docs/{package_name}/{version}/{language}/"
    # Links of the form :doc:`page`
    text = _re_simple_doc.sub(rf"[\1]({prefix}\1)", text)
    # Links of the form :doc:`text <page>`
    text = _re_doc_with_description.sub(rf"[\1]({prefix}\2)", text)
    if "page" in page_info and not no_prefix:
        page = str(page_info["page"])
        if page.endswith(".html"):
            page = page[:-5]
        prefix = f"{prefix}{page}"
    else:
        prefix = ""
    # Refs of the form :ref:`page`
    text = _re_simple_ref.sub(rf"[\1]({prefix}#\1)", text)
    # Refs of the form :ref:`text <page>`
    text = _re_ref_with_description.sub(rf"[\1]({prefix}#\2)", text)
    # Links with a prefix
    # TODO: when it exists, use the API to deal with prefix links properly.
    prefix = f"https://github.com/huggingface/{package_name}/tree/main/"
    text = _re_prefix_links.sub(rf"[\1]({prefix}\2)", text)
    # Other links
    text = _re_links.sub(r"[\1](\2)", text)
    # Relative links or Transformers links need to remove the .html
    if (
        "(https://https://huggingface.co/" in text
        or re.search(r"\(\.+/", text) is not None
    ):
        text = text.replace(".html", "")
    return text
# Regexes for rst block-level constructs, used by `convert_rst_blocks` below.
# Re pattern that catches examples blocks of the form `Example::`.
_re_example = re.compile(r"^\s*(\S.*)::\s*$")
# Re pattern that catches rst blocks of the form `.. block_name::`.
_re_block = re.compile(r"^\s*\.\.\s+(\S+)::")
# Re pattern that catches what's after the :: in rst blocks of the form `.. block_name:: something`.
_re_block_info = re.compile(r"^\s*\.\.\s+\S+::\s*(\S.*)$")
def is_empty_line(line):
    """Return True when `line` is empty or contains only whitespace."""
    return not line or line.isspace()
def find_indent(line):
    """
    Returns the number of spaces that start a line indent.
    """
    # Counting leading whitespace directly is equivalent to matching
    # `^(\s*)` against the line.
    return len(line) - len(line.lstrip())
# Re pattern that catches rst options of the form `:name: value` inside a block.
_re_rst_option = re.compile(r"^\s*:(\S+):(.*)$")


def convert_special_chars(text):
    """
    Converts { and < that have special meanings in MDX.

    Args:
        text (`str`): The text to escape.

    Returns:
        `str`: The text with `{` replaced and lone `<` escaped, while `<` that
        opens a recognized HTML tag (matched open/close pairs or known void
        elements) is left intact.
    """
    text = text.replace("{", "&amp;lcub;")
    # We don't want to replace those by the HTML code, so we temporarily set them at LTHTML
    text = re.sub(
        r"<(img|br|hr|Youtube)", r"LTHTML\1", text
    )  # html void elements with no closing counterpart
    # Matches a full <tag ...>...</tag> pair (innermost first, since the body
    # may not contain the closing tag).
    _re_lt_html = re.compile(r"<(\S+)([^>]*>)(((?!</\1>).)*)<(/\1>)", re.DOTALL)
    while _re_lt_html.search(text):
        text = _re_lt_html.sub(r"LTHTML\1\2\3LTHTML\5", text)
    # Any remaining `<` is not part of an HTML tag: escape it.
    text = re.sub(r"(^|[^<])<([^<]|$)", r"\1&amp;lt;\2", text)
    # Restore the protected tag openers.
    text = text.replace("LTHTML", "<")
    return text
def parse_options(block_content):
    """
    Parses the option in some rst block content.

    Args:
        block_content (`str`): The body of an rst block, whose lines may be
            `:name: value` options with indented continuation lines.

    Returns:
        `dict`: Mapping of option name to its (whitespace-joined) value.
    """
    block_lines = block_content.split("\n")
    block_indent = find_indent(block_lines[0])
    current_option = None
    result = {}
    for line in block_lines:
        if _re_rst_option.search(line) is not None:
            current_option, value = _re_rst_option.search(line).groups()
            result[current_option] = value.lstrip()
        elif find_indent(line) > block_indent:
            # Continuation of the previous option's value.
            # NOTE(review): if an indented line precedes the first `:name:`
            # option, `current_option` is still None and this raises
            # KeyError — presumably inputs are well-formed; confirm upstream.
            result[current_option] += " " + line.lstrip()
    return result
def apply_min_indent(text, min_indent):
    """
    Make sure all lines in a text have a minimum indentation.

    Args:
        text (`str`): The text to treat.
        min_indent (`int`): The minimal indentation.

    Returns:
        `str`: The processed text.
    """
    lines = text.split("\n")
    idx = 0
    while idx < len(lines):
        if is_empty_line(lines[idx]):
            idx += 1
            continue
        indent = find_indent(lines[idx])
        if indent < min_indent:
            # Shift this whole under-indented block (all following lines at
            # the same or deeper indent, plus blank lines) right by the same
            # amount, preserving relative indentation inside the block.
            while idx < len(lines) and (
                find_indent(lines[idx]) >= indent or is_empty_line(lines[idx])
            ):
                if not is_empty_line(lines[idx]):
                    lines[idx] = " " * (min_indent - indent) + lines[idx]
                idx += 1
        else:
            idx += 1

    return "\n".join(lines)
def convert_rst_blocks(text, page_info):
    """
    Converts rst special blocks (examples, notes) into MDX.

    Args:
        text (`str`): The text to convert.
        page_info (`dict`): Must contain `"package_name"`; may contain
            `"version"` (default `"main"`) and `"language"` (default `"en"`),
            used to rewrite image paths.

    Returns:
        `str`: The text with `.. block::` directives, `Example::` blocks,
        comments and literal blocks converted to MDX constructs.

    Raises:
        ValueError: If `page_info` lacks `"package_name"`, or an image block
            has neither inline source nor a `:target:` option.
    """
    if "package_name" not in page_info:
        raise ValueError("`page_info` must contain at least the package_name.")
    package_name = page_info["package_name"]
    version = page_info.get("version", "main")
    language = page_info.get("language", "en")

    lines = text.split("\n")
    idx = 0
    new_lines = []
    while idx < len(lines):
        block_type = None
        block_info = None
        # Identify what kind of block (if any) starts on this line.
        if _re_block.search(lines[idx]) is not None:
            block_type = _re_block.search(lines[idx]).groups()[0]
            if _re_block_info.search(lines[idx]) is not None:
                block_info = _re_block_info.search(lines[idx]).groups()[0]
        elif _re_example.search(lines[idx]) is not None:
            block_type = "code-block-example"
            block_info = "python"
            example_name = _re_example.search(lines[idx]).groups()[0]
            new_lines.append(f"<exampletitle>{example_name}:</exampletitle>\n")
        elif lines[idx].strip() == "..":
            block_type = "comment"
        elif lines[idx].strip() == "::":
            block_type = "code-block"

        if block_type is not None:
            block_indent = find_indent(lines[idx])
            # Find the next nonempty line
            idx += 1
            while idx < len(lines) and is_empty_line(lines[idx]):
                idx += 1
            # Grab the indent of the return line, this block will stop when we unindent under it (or has already)
            example_indent = (
                find_indent(lines[idx]) if idx < len(lines) else block_indent
            )
            if example_indent == block_indent:
                block_content = ""
            else:
                # Collect all lines belonging to the block (same or deeper
                # indent, blank lines included), de-indented.
                block_lines = []
                while idx < len(lines) and (
                    is_empty_line(lines[idx])
                    or find_indent(lines[idx]) >= example_indent
                ):
                    block_lines.append(lines[idx][example_indent:])
                    idx += 1
                block_content = "\n".join(block_lines)

            # Emit the MDX construct matching the block type.
            if block_type in ["code", "code-block"]:
                prefix = "```" if block_info is None else f"```{block_info}"
                new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n")
            elif block_type == "code-block-example":
                prefix = f"<example>```{block_info}"
                new_lines.append(f"{prefix}\n{block_content.strip()}\n```\n</example>")
            elif block_type == "note":
                new_lines.append(
                    apply_min_indent(
                        f"<Tip>\n\n{block_content.strip()}\n\n</Tip>\n", block_indent
                    )
                )
            elif block_type == "warning":
                new_lines.append(
                    apply_min_indent(
                        "<Tip warning={true}>\n\n"
                        + f"{block_content.strip()}\n\n</Tip>\n",
                        block_indent,
                    )
                )
            elif block_type == "raw":
                new_lines.append(block_content.strip() + "\n")
            elif block_type == "math":
                new_lines.append(f"$${block_content.strip()}$$\n")
            elif block_type == "comment":
                new_lines.append(f"<!--{block_content.strip()}\n-->\n")
            elif block_type == "autofunction":
                if block_info is not None:
                    new_lines.append(f"[[autodoc]] {block_info}\n")
            elif block_type == "autoclass":
                if block_info is not None:
                    block = f"[[autodoc]] {block_info}\n"
                    options = parse_options(block_content)
                    if "special-members" in options:
                        special_members = options["special-members"].split(", ")
                        for special_member in special_members:
                            block += f" - {special_member}\n"
                    if "members" in options:
                        members = options["members"]
                        if len(members) == 0:
                            block += " - all\n"
                        else:
                            for member in members.split(", "):
                                block += f" - {member}\n"
                    new_lines.append(block)
            elif block_type == "image":
                options = parse_options(block_content)
                target = options.pop("target", None)
                if block_info is not None:
                    options["src"] = block_info
                else:
                    if target is None:
                        raise ValueError("Image source not defined.")
                    options["src"] = target
                # Adapt path
                options["src"] = options["src"].replace(
                    "/imgs/", f"/docs/{package_name}/{version}/{language}/imgs/"
                )
                html_code = " ".join(
                    [f'{key}="{value}"' for key, value in options.items()]
                )
                new_lines.append(f"<img {html_code}/>\n")
            else:
                # Unknown block type: emit it verbatim as "type,info" plus
                # its content so nothing is silently dropped.
                new_lines.append(
                    f"{block_type},{block_info}\n{block_content.rstrip()}\n"
                )
        else:
            new_lines.append(lines[idx])
            idx += 1

    return "\n".join(new_lines)
# Re pattern that catches rst args blocks of the form `Parameters:`.
_re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
# Re pattern that catches return/yield/raise blocks of the form `Returns:`,
# `Yields:` or `Raises:` (group 1 captures the singular keyword).
_re_returns = re.compile(r"^\s*(Return|Yield|Raise)s?:\s*$")
def split_return_line(line):
    """
    Split a return line of the form `type: some doc` into (type, doc).

    Colons belonging to rst roles (:obj:/:class:) are kept inside the type
    part. Returns (None, line) when no type can be identified, and
    (line, "") when the whole line looks like a bare type ending in
    `identifier`.
    """
    parts = line.split(":")
    split_at = 1
    while split_at < len(parts) and parts[split_at] in ("obj", "class"):
        split_at += 2
    if split_at < len(parts):
        return ":".join(parts[:split_at]), ":".join(parts[split_at:])
    # No usable colon found.
    if len(parts) % 2 == 1 and re.search(r"`\w+`$", line.rstrip()):
        return line, ""
    return None, line
def split_raise_line(line):
    """
    Split a raise line of the form `SomeError some doc` into
    (error_type, doc). A trailing colon on the error type is dropped.
    """
    words = line.strip().split(" ")
    error_type = words[0]
    doc = " ".join(words[1:])
    if error_type.endswith(":"):
        error_type = error_type[:-1]
    return error_type, doc
def split_arg_line(line):
    """
    Split an argument line of the form `name (type): some doc` into
    (intro, doc), keeping rst-role colons (:obj:/:class:) inside the intro.
    Returns (line, "") when no splitting colon is found.
    """
    parts = line.split(":")
    split_at = 1
    while split_at < len(parts) and parts[split_at] in ("obj", "class"):
        split_at += 2
    if split_at < len(parts):
        return ":".join(parts[:split_at]), ":".join(parts[split_at:])
    return line, ""
class InvalidRstDocstringError(ValueError):
    """Error type for rst docstrings that cannot be parsed.

    NOTE(review): not raised anywhere in this excerpt — presumably used by
    callers elsewhere in the module; confirm.
    """

    pass


# Matches one <parameters>...</parameters> section (no nested openers),
# used to merge multiple parameter sections into one.
_re_parameters = re.compile(
    r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
# Matches a markdown link `[text](url)`; group 1 is the link text.
_re_md_link = re.compile(r"\[(.+)\]\(.+\)", re.DOTALL)
def parse_rst_docstring(docstring):
    """
    Parses a docstring written in rst, in particular the list of arguments and the return type.

    Args/Attributes/Parameters sections are rewritten in place into a
    `<parameters>` block of markdown list items; Returns/Yields/Raises
    sections become `<returns>`/`<yields>` blocks plus `<returntype>`/
    `<yieldtype>`/`<raisederrors>` annotations.

    Args:
        docstring (`str`): The rst docstring to parse.

    Returns:
        `str`: The transformed docstring, with multiple `<parameters>`
        blocks merged into a single trailing one.
    """
    lines = docstring.split("\n")
    idx = 0
    while idx < len(lines):
        # Parameters section
        if _re_args.search(lines[idx]) is not None:
            # Title of the section.
            lines[idx] = "<parameters>\n"
            # Find the next nonempty line
            # NOTE(review): raises IndexError if the section title is the
            # last (or only trailed by blank) line — presumably docstrings
            # are well-formed; confirm upstream.
            idx += 1
            while is_empty_line(lines[idx]):
                idx += 1
            # Grab the indent of the list of parameters, this block will stop when we unindent under it or we see the
            # Returns or Raises block.
            param_indent = find_indent(lines[idx])
            while (
                idx < len(lines)
                and find_indent(lines[idx]) == param_indent
                and _re_returns.search(lines[idx]) is None
            ):
                intro, doc = split_arg_line(lines[idx])
                # Line starting with a > after indent indicate a "section title" in the parameters.
                if intro.lstrip().startswith(">"):
                    lines[idx] = intro.lstrip()
                else:
                    # Turn `name (type): doc` into a markdown list item with
                    # the name in bold.
                    lines[idx] = (
                        re.sub(r"^\s*(\S+)(\s)?", r"- **\1**\2", intro) + " --" + doc
                    )
                idx += 1
                # Skip the continuation lines of this parameter.
                while idx < len(lines) and (
                    is_empty_line(lines[idx]) or find_indent(lines[idx]) > param_indent
                ):
                    idx += 1
            lines.insert(idx, "</parameters>\n")
            idx += 1
        # Returns section
        elif _re_returns.search(lines[idx]) is not None:
            # tag is either `return` or `yield`
            tag = _re_returns.match(lines[idx]).group(1).lower()
            # Title of the section.
            lines[idx] = f"<{tag}s>\n"
            # Find the next nonempty line
            idx += 1
            while is_empty_line(lines[idx]):
                idx += 1
            # Grab the indent of the return line, this block will stop when we unindent under it.
            return_indent = find_indent(lines[idx])
            raised_errors = []
            # The line may contain the return type.
            if tag in ["return", "yield"]:
                return_type, return_description = split_return_line(lines[idx])
                lines[idx] = return_description
                idx += 1
                while idx < len(lines) and (
                    is_empty_line(lines[idx])
                    or find_indent(lines[idx]) >= return_indent
                ):
                    idx += 1
            else:
                # Raises section: one error per line at the section indent.
                while idx < len(lines) and find_indent(lines[idx]) == return_indent:
                    return_type, return_description = split_raise_line(lines[idx])
                    raised_error = re.sub(r"^\s*`?([\w\.]*)`?$", r"``\1``", return_type)
                    lines[idx] = "- " + raised_error + " -- " + return_description
                    md_link = _re_md_link.match(raised_error)
                    if md_link:
                        # Error written as a markdown link: keep only its text.
                        raised_error = md_link[1]
                        raised_error = re.sub(
                            r"^\s*`?([\w\.]*)`?$", r"``\1``", raised_error
                        )
                    if raised_error not in raised_errors:
                        raised_errors.append(raised_error)
                    idx += 1
                    # Skip the continuation lines of this error description.
                    while idx < len(lines) and (
                        is_empty_line(lines[idx])
                        or find_indent(lines[idx]) > return_indent
                    ):
                        idx += 1
            lines.insert(idx, f"</{tag}s>\n")
            idx += 1
            # Return block finished, we insert the return type if one was specified
            if tag in ["return", "yield"] and return_type is not None:
                lines[idx - 1] += f"\n<{tag}type>{return_type}</{tag}type>\n"
            elif len(raised_errors) > 0:
                # raised errors
                lines[idx - 1] += (
                    f"\n<raisederrors>{' or '.join(raised_errors)}</raisederrors>\n"
                )
        else:
            idx += 1

    result = "\n".join(lines)
    # combine multiple <parameters> blocks into one block
    if result.count("<parameters>") > 1:
        parameters_blocks = _re_parameters.findall(result)
        parameters_blocks = [pb[0].strip() for pb in parameters_blocks]
        parameters_str = "\n".join(parameters_blocks)
        result = _re_parameters.sub("", result)
        result += f"\n<parameters>{parameters_str}</parameters>\n"

    return result
# Matches a list item marker: `-`, `*` or `1.` followed by a space.
_re_list = re.compile(r"^\s*(-|\*|\d+\.)\s")
# Matches an `[[autodoc]] target` line on its own.
_re_autodoc = re.compile(r"^\s*\[\[autodoc\]\]\s+(\S+)\s*$")
def remove_indent(text):
    """
    Remove indents in text, except the one linked to lists (or sublists).

    List items keep a normalized indentation derived from their nesting level;
    code blocks keep their internal layout (dedented by the fence's indent);
    everything else is flushed left unless it continues a list item.
    """
    lines = text.split("\n")
    # List of indents to remember for nested lists
    current_indents = []
    # List of new indents to remember for nested lists
    new_indents = []
    # Tracks whether we are between a pair of ``` fences.
    is_inside_code = False
    # Indent of the opening fence; stripped from every line of the code block.
    code_indent = 0
    for idx, line in enumerate(lines):
        # Line is an item in a list.
        if _re_list.search(line) is not None:
            indent = find_indent(line)
            # Is it a new list / new level of nestedness?
            if len(current_indents) == 0 or indent > current_indents[-1]:
                current_indents.append(indent)
                new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
                lines[idx] = " " * new_indent + line[indent:]
                # Continuation text must align after the bullet marker,
                # hence marker length + 1 (the space after it).
                new_indent += len(_re_list.search(line).groups()[0]) + 1
                new_indents.append(new_indent)
            # Otherwise it's an existing level of list (current one, or previous one)
            else:
                # Let's find the proper level of indentation
                level = len(current_indents) - 1
                while level >= 0 and current_indents[level] != indent:
                    level -= 1
                # Drop the deeper levels we just unindented out of.
                current_indents = current_indents[: level + 1]
                new_indents = new_indents[:level]
                new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
                lines[idx] = " " * new_indent + line[indent:]
                new_indent += len(_re_list.search(line).groups()[0]) + 1
                new_indents.append(new_indent)
        # Line is an autodoc, we keep the indent for the list just after if there is one.
        elif _re_autodoc.search(line) is not None:
            indent = find_indent(line)
            current_indents = [indent]
            new_indents = [4]
            lines[idx] = line.strip()
        # Deal with empty lines separately
        elif is_empty_line(line):
            lines[idx] = ""
        # Code blocks
        elif line.lstrip().startswith("```"):
            is_inside_code = not is_inside_code
            if is_inside_code:
                code_indent = find_indent(line)
            lines[idx] = line[code_indent:]
        elif is_inside_code:
            # Inside a code block: only strip the fence's indent, keep the rest.
            lines[idx] = line[code_indent:]
        else:
            indent = find_indent(line)
            # More indented than the current list item: it continues that item.
            if len(current_indents) > 0 and indent > current_indents[-1]:
                lines[idx] = " " * new_indents[-1] + line[indent:]
            elif len(current_indents) > 0:
                # Let's find the proper level of indentation
                level = len(current_indents) - 1
                while level >= 0 and current_indents[level] > indent:
                    level -= 1
                current_indents = current_indents[: level + 1]
                if level >= 0:
                    if current_indents[level] < indent:
                        new_indents = new_indents[: level + 1]
                    else:
                        new_indents = new_indents[:level]
                    new_indent = 0 if len(new_indents) == 0 else new_indents[-1]
                    lines[idx] = " " * new_indent + line[indent:]
                    new_indents.append(new_indent)
                else:
                    # Unindented past every tracked list level: reset state.
                    new_indents = []
                    lines[idx] = line[indent:]
            else:
                # Not in a list context: flush the line left.
                lines[idx] = line[indent:]
    return "\n".join(lines)
def base_rst_to_mdx(text, page_info, unindent=True):
    """
    Apply the base rst -> mdx conversion passes shared by docstrings and rst
    documents: links, special characters, blocks and inline formatting.

    Args:
        text: the rst source to convert.
        page_info: contextual info (package name, version, ...) forwarded to
            the individual conversion passes.
        unindent: whether to run `remove_indent` on the result.
    """
    converted = convert_rst_links(text, page_info)
    converted = convert_special_chars(converted)
    converted = convert_rst_blocks(converted, page_info)
    # Convert * in lists to - to avoid the formatting conversion treat them as bold.
    converted = re.sub(r"^(\s*)\*(\s)", r"\1-\2", converted, flags=re.MULTILINE)
    converted = convert_rst_formatting(converted)
    if not unindent:
        return converted
    return remove_indent(converted)
def convert_rst_docstring_to_mdx(docstring, page_info):
    """
    Convert a docstring written in rst to mdx.

    Parses the docstring's structured fields first, then applies the shared
    rst -> mdx conversion pipeline.
    """
    return base_rst_to_mdx(parse_rst_docstring(docstring), page_info)
def process_titles(lines):
    """Converts rst titles to markdown titles.

    An rst title is a line followed by an underline made of a single repeated
    punctuation character at least as long as the title. The first underline
    character encountered maps to `#`, the second distinct one to `##`, etc.
    """
    title_chars = """= - ` : ' " ~ ^ _ * + # < >""".split(" ")
    # Maps an underline character to its heading level, in discovery order.
    title_levels = {}
    new_lines = []
    for line in lines:
        is_underline = (
            len(new_lines) > 0
            and len(line) >= len(new_lines[-1])
            and len(set(line)) == 1
            and line[0] in title_chars
            and line != "::"
        )
        if is_underline:
            # Reuse the level for a known character, otherwise assign the next one.
            level = title_levels.setdefault(line[0], len(title_levels) + 1)
            new_lines[-1] = f"{'#' * level} {new_lines[-1]}"
        else:
            new_lines.append(line)
    return new_lines
# Matches lines with a pattern of a table new line in rst.
_re_ignore_line_table = re.compile(r"^(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a table new line in rst, with a first column empty.
_re_ignore_line_table1 = re.compile(r"^\|\s+(\+[\-\s]+)+\+\s*$")
# Matches lines with a pattern of a first table line in rst.
_re_sep_line_table = re.compile(r"^(\+[=\s]+)+\+\s*$")
# Re pattern that catches anchors of the type .. reference:
_re_anchor_section = re.compile(r"^\.\.\s+_(\S+):")
def split_pt_tf_code_blocks(text):
    """
    Split PyTorch and TensorFlow specific block codes.

    A fenced code block containing `## PYTORCH CODE` / `## TENSORFLOW CODE`
    markers is duplicated into a `<frameworkcontent>` section with one `<pt>`
    and one `<tf>` block; blocks without markers are left untouched.
    """
    source_lines = text.split("\n")
    result = []
    i = 0
    total = len(source_lines)
    while i < total:
        current = source_lines[i]
        if not current.startswith("```"):
            result.append(current)
            i += 1
            continue
        # Bucket the fenced block's lines by the framework marker in effect.
        buckets = {"common": [current], "pytorch": [], "tensorflow": []}
        target = "common"
        i += 1
        while i < total and source_lines[i].strip() != "```":
            body_line = source_lines[i]
            if "## PYTORCH CODE" in body_line:
                target = "pytorch"
            elif "## TENSORFLOW CODE" in body_line:
                target = "tensorflow"
            else:
                buckets[target].append(body_line)
            i += 1
        if buckets["pytorch"] or buckets["tensorflow"]:
            # Emit one block per framework, each prefixed by the shared lines.
            result.append("<frameworkcontent>")
            result.append("<pt>")
            result.extend(buckets["common"] + buckets["pytorch"])
            result.extend(["```", "</pt>", "<tf>"])
            result.extend(buckets["common"] + buckets["tensorflow"])
            result.extend(["```", "</tf>", "</frameworkcontent>"])
        else:
            result.extend(buckets["common"] + ["```"])
        # Skip past the closing fence.
        i += 1
    return "\n".join(result)
def convert_rst_to_mdx(rst_text, page_info, add_imports=True):
    """
    Convert a document written in rst to mdx.

    Args:
        rst_text: the rst document source.
        page_info: contextual info forwarded to the conversion passes.
        add_imports: whether to prepend the Svelte `<script>` preamble with
            the component imports used by the generated mdx.
    """
    lines = rst_text.split("\n")
    lines = process_titles(lines)
    if add_imports:
        new_lines = [
            '<script lang="ts">',
            '   import Tip from "$lib/Tip.svelte";',
            '   import Youtube from "$lib/Youtube.svelte";',
            '   import Docstring from "$lib/Docstring.svelte";',
            '   import CodeBlock from "$lib/CodeBlock.svelte";',
            '   import CodeBlockFw from "$lib/CodeBlockFw.svelte";',
            '   import DocNotebookDropdown from "$lib/DocNotebookDropdown.svelte";',
            '   import CourseFloatingBanner from "$lib/CourseFloatingBanner.svelte";',
            '   import IconCopyLink from "$lib/IconCopyLink.svelte";',
            '   import FrameworkContent from "$lib/FrameworkContent.svelte";',
            '   import Markdown from "$lib/Markdown.svelte";',
            '   import ExampleCodeBlock from "$lib/ExampleCodeBlock.svelte";',
            '   import Added from "$lib/Added.svelte";',
            '   import Changed from "$lib/Changed.svelte";',
            '   import Deprecated from "$lib/Deprecated.svelte";',
            '   import PipelineIcon from "$lib/PipelineIcon.svelte";',
            '   import PipelineTag from "$lib/PipelineTag.svelte";',
            "   ",
            '   export let fw: "pt" | "tf"',
            "</script>",
            "<svelte:head>",
            '<meta name="hf:doc:metadata" content={JSON.stringify(metadata)} >',
            "</svelte:head>",
            "",
        ]
    else:
        new_lines = []
    for line in lines:
        # Drop rst table border lines entirely; they have no mdx equivalent.
        if _re_ignore_line_table.search(line) is not None:
            continue
        elif _re_ignore_line_table1.search(line) is not None:
            continue
        # Header separator of an rst table becomes a markdown separator row.
        elif _re_sep_line_table.search(line) is not None:
            line = line.replace("=", "-").replace("+", "|")
        # `.. _name:` anchors become HTML anchors.
        elif _re_anchor_section.search(line) is not None:
            anchor_name = _re_anchor_section.search(line).groups()[0]
            line = f"<a id='{anchor_name}'></a>"
        new_lines.append(line)
    text = "\n".join(new_lines)
    return split_pt_tf_code_blocks(base_rst_to_mdx(text, page_info))

View File

@ -8,6 +8,9 @@ from kernels.compat import tomllib
from kernels.lockfile import KernelLock, get_kernel_locks
from kernels.utils import install_kernel, install_kernel_all_variants
from .doc import generate_readme_for_kernel
from .wheel import build_variant_to_wheel
def main():
parser = argparse.ArgumentParser(
@ -36,6 +39,47 @@ def main():
)
lock_parser.set_defaults(func=lock_kernels)
to_wheel_parser = subparsers.add_parser(
"to-wheel", help="Convert a kernel to a wheel file"
)
to_wheel_parser.add_argument("repo_id", type=str, help="The kernel repo ID")
to_wheel_parser.add_argument("version", type=str, help="The kernel version")
to_wheel_parser.add_argument(
"--python-version",
type=str,
default="3.9",
help="The minimum Python version. Must match the Python version that the kernel was compiled for.",
)
to_wheel_parser.add_argument(
"--manylinux-version",
type=str,
default="2.28",
help="The manylinux version. Must match the manylinux version that the kernel was compiled for.",
)
to_wheel_parser.set_defaults(func=kernels_to_wheel)
# Add generate-readme subcommand parser
generate_readme_parser = subparsers.add_parser(
"generate-readme",
help="Generate README snippets for a kernel's public functions",
)
generate_readme_parser.add_argument(
"repo_id",
type=str,
help="The kernel repo ID (e.g., kernels-community/activation)",
)
generate_readme_parser.add_argument(
"--revision",
type=str,
default="main",
help="The kernel revision (branch, tag, or commit SHA, defaults to 'main')",
)
generate_readme_parser.set_defaults(
func=lambda args: generate_readme_for_kernel(
repo_id=args.repo_id, revision=args.revision
)
)
args = parser.parse_args()
args.func(args)
@ -77,6 +121,24 @@ def download_kernels(args):
sys.exit(1)
def kernels_to_wheel(args):
variants_path = install_kernel_all_variants(
repo_id=args.repo_id, revision=f"v{args.version}"
)
for variant_path in variants_path.iterdir():
if not variant_path.is_dir():
continue
wheel_path = build_variant_to_wheel(
manylinux_version=args.manylinux_version,
python_version=args.python_version,
repo_id=args.repo_id,
version=args.version,
variant_path=variant_path,
wheel_dir=Path("."),
)
print(f"☸️ {wheel_path.name}", file=sys.stderr)
def lock_kernels(args):
with open(args.project_dir / "pyproject.toml", "rb") as f:
data = tomllib.load(f)

124
src/kernels/doc.py Normal file
View File

@ -0,0 +1,124 @@
import inspect
import re
import sys
from types import ModuleType
import yaml
from ._vendored.convert_rst_to_mdx import convert_rst_docstring_to_mdx
from .utils import get_kernel
# Matches one `<parameters>...</parameters>` block; the tempered inner pattern
# stops the match at the next opening tag so multiple blocks are not merged.
_RE_PARAMETERS = re.compile(
    r"<parameters>(((?!<parameters>).)*)</parameters>", re.DOTALL
)
# Matches one `<returns>...</returns>` block.
_RE_RETURNS = re.compile(r"<returns>(((?!<returns>).)*)</returns>", re.DOTALL)
# Matches one `<returntype>...</returntype>` block.
_RE_RETURNTYPE = re.compile(
    r"<returntype>(((?!<returntype>).)*)</returntype>", re.DOTALL
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
    """
    Print a README for a kernel repo to stdout: YAML metadata front matter,
    the module-level documentation and per-function documentation.

    Args:
        repo_id: the kernel repo ID, e.g. ``kernels-community/activation``.
        revision: branch, tag or commit SHA to fetch the kernel from.
    """
    module = get_kernel(repo_id=repo_id, revision=revision)
    # The importable package name is the repo name with dashes as underscores.
    package_name = repo_id.split("/")[-1].replace("-", "_")
    generate_metadata(module)
    generate_kernel_doc(module, package_name)
    generate_function_doc(module, package_name)
def generate_metadata(module: ModuleType) -> None:
    """
    Print the kernel's metadata as a YAML front matter block.

    Always includes the ``kernel`` tag. Works on copies so the module's own
    ``__kernel_metadata__`` dict (and its tag list) is never mutated — the
    original implementation appended to the module's list in place.

    Args:
        module: the kernel's Python module.
    """
    metadata = dict(getattr(module, "__kernel_metadata__", {}))
    # Copy the tag list before appending to avoid mutating module state.
    tags = list(metadata.get("tags", []))
    if "kernel" not in tags:
        tags.append("kernel")
    metadata["tags"] = tags
    print("---")
    print(yaml.dump(metadata), end="")
    print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
    """
    Print the kernel module's docstring as Markdown.

    The first docstring line becomes the page title; the remaining lines (if
    any) are converted from rst to mdx. The original implementation unpacked
    ``docstring.split("\\n", 1)`` into two names, which raises ``ValueError``
    for a single-line docstring; ``partition`` handles that case safely.

    Args:
        module: the kernel's Python module.
        kernel_name: package name passed to the rst -> mdx converter.
    """
    docstring = module.__doc__.strip() if module.__doc__ is not None else None
    if docstring:
        # partition never fails: a single-line docstring yields an empty rest.
        title, _, rest = docstring.partition("\n")
        print(f"# {title.strip()}")
        rest = rest.strip()
        if rest:
            print(
                f"\n{convert_rst_docstring_to_mdx(rest, page_info={'package_name': kernel_name})}"
            )
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
    """
    Print Markdown documentation for the kernel's public top-level functions.

    Only functions defined in the module itself (not imported) whose names do
    not start with an underscore are included. Each entry shows the signature
    and the docstring, with `<parameters>`/`<returns>`/`<returntype>` blocks
    produced by the rst -> mdx conversion rendered as dedicated sections.

    Args:
        kernel_module: the kernel's Python module.
        kernel_name: package name passed to the rst -> mdx converter.

    Raises:
        ValueError: when a docstring yields more than one `<returntype>` tag.
    """
    functions_info = []
    for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
        # Do not include imported functions.
        if func.__module__ == kernel_module.__name__:
            # Exclude private functions.
            if not name.startswith("_"):
                try:
                    sig = inspect.signature(func)
                    docstring = inspect.getdoc(func) or "No documentation available."
                    functions_info.append((name, sig, docstring))
                except ValueError:
                    # Builtins/extension functions may expose no signature.
                    print(
                        f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
                        file=sys.stderr,
                    )
    print("\n## Functions")
    if not functions_info:
        print(
            "\nNo public top-level functions.",
        )
        return
    for name, sig, docstring in functions_info:
        print(f"\n### Function `{name}`")
        print(f"\n`{sig}`")
        docstring_mdx = convert_rst_docstring_to_mdx(
            docstring, page_info={"package_name": kernel_name}
        )
        # Locate the first structured tag; the prose before it is the
        # function description.
        params_pos = docstring_mdx.find("<parameters>")
        returns_pos = docstring_mdx.find("<returns>")
        returntype_pos = docstring_mdx.find("<returntype>")
        positions = [
            pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
        ]
        if positions:
            first_tag_pos = min(positions)
            # The function description is anything before the first tag.
            print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
        else:
            print(f"\n{docstring_mdx.strip()}")
        # Extract parameters
        matches = _RE_PARAMETERS.findall(docstring_mdx)
        if matches:
            print("\n### Parameters")
            for match in matches:
                print(f"\n{match[0].strip()}")
        # Extract return information
        return_matches = _RE_RETURNS.findall(docstring_mdx)
        returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
        if return_matches or returntype_matches:
            print("\n### Returns", file=sys.stdout)
            if returntype_matches:
                if len(returntype_matches) > 1:
                    raise ValueError(
                        f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
                    )
                print(
                    f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
                )
            if return_matches:
                for match in return_matches:
                    print(f"\n{match[0].strip()}")

View File

@ -4,13 +4,16 @@ import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union
from types import MethodType
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from .utils import get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
@ -53,6 +56,9 @@ class LayerRepository:
return hash((self.layer_name, self.repo_id, self.revision))
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
)
@ -87,11 +93,13 @@ def use_kernel_mapping(
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
):
"""
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
Allows one to register a mapping between a layer name the corresponding
kernel to use, depending on the device. This should be use in conjunction
with `kernelize`.
Exemple usage:
```python
@ -119,26 +127,70 @@ def register_kernel_mapping(
device_repo[new_device] = new_repo
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
Args:
cls: The layer class to decorate
layer_name: The name of the layer to use for kernel lookup
"""
cls.kernel_layer_name = layer_name
fallback_forward = cls.forward
cached_layer: Dict[LayerRepository, nn.Module] = {}
def kernelize(
model: "nn.Module",
device: Optional[Union[str, "torch.device"]] = None,
needs_torch_compile: bool = False,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
needs_torch_compile: When set to `true`, only kernels that support
`torch.compile` will be loaded.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
"""
import torch
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
def forward(self, x, *args, **kwargs):
if _DISABLE_KERNEL_MAPPING:
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
needs_backward = self.training
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
warnings.warn(
"\n"
@ -148,26 +200,30 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, *args, **kwargs)
# Use device type string directly instead of Device object
repo = kernel.get(device_type)
repo = kernel.get(Device(type=device.type))
if repo is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
# Short-circuit if we already loaded the layer.
layer = cached_layer.get(repo, None)
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
if needs_backward and not getattr(layer, "has_backward", True):
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
_conditionally_replace_forward(
module=module,
layer=layer,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
continue
layer = _get_kernel_layer(
repo_id=repo.repo_id,
@ -175,41 +231,36 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
revision=repo.revision,
)
# We have to validate against the original signature.
orig_forward = cls.forward
try:
cls.forward = fallback_forward
_validate_layer(check_cls=cls, cls=layer)
finally:
cls.forward = orig_forward
# Validate the replacement layer against the class layer.
_validate_layer(check_cls=module_class, cls=layer)
cached_layer[repo] = layer
_CACHED_LAYER[repo] = layer
if needs_backward and not getattr(layer, "has_backward", True):
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
_conditionally_replace_forward(
module=module,
layer=layer,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
cls.forward = forward
return model
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
def use_kernel_forward_from_hub(layer_name: str):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Make a layer extensible using the name `layer_name`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
@ -226,13 +277,13 @@ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Mo
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
# The layer must have at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
from torch import nn
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
@ -245,7 +296,8 @@ def _validate_layer(*, check_cls, cls):
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
if difference != set() and difference != {"has_backward"}:
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
# Check whether the forward signatures are similar.
@ -262,3 +314,62 @@ def _validate_layer(*, check_cls, cls):
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)
def _find_device(model: "nn.Module") -> Device:
try:
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)
return Device(type=param.device.type)
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
needs_torch_compile: bool,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback when the layer does not support:
# compilation/compile when needed.
# backward when needed
needs_fallback = needs_torch_compile and not getattr(
layer, "can_torch_compile", False
)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(
f"Available kernel does not fulfill requirements: needs_torch_compile={needs_torch_compile}"
)
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
import torch.nn as nn
module_class = type(module)
layer_with_backward = (
layer if getattr(layer, "has_backward", True) else module_class
)
def train(self, mode: bool = True) -> nn.Module:
super(type(self), self).train(mode)
if mode:
self.forward = MethodType(layer_with_backward.forward, self)
else:
self.forward = MethodType(layer.forward, self)
return self
module.train = MethodType(train, module) # type: ignore[method-assign]
# Trigger setting correct forward for the current state.
module.train(module.training)

View File

@ -43,14 +43,22 @@ def build_variant() -> str:
elif torch.version.hip is not None:
rocm_version = parse(torch.version.hip.split("-")[0])
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
elif torch.backends.mps.is_available():
compute_framework = "metal"
else:
raise AssertionError("Torch was not compiled with CUDA or ROCm enabled.")
raise AssertionError(
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
)
torch_version = parse(torch.__version__)
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
cpu = platform.machine()
os = platform.system().lower()
if os == "darwin":
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"

186
src/kernels/wheel.py Normal file
View File

@ -0,0 +1,186 @@
import email.policy
import os
from dataclasses import dataclass
from email.message import Message
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Optional
try:
KERNELS_VERSION = version("kernels")
except PackageNotFoundError:
KERNELS_VERSION = "unknown"
@dataclass
class Metadata:
    """Build-variant metadata used to construct wheel filenames and tags."""

    # Kernel name (repo name; callers normalize dashes/underscores).
    name: str
    # Kernel version string, e.g. "0.6.0".
    version: str
    # CUDA version digits from the variant name (e.g. "118"); None for universal builds.
    cuda_version: Optional[str]
    # C++ ABI version from the variant name (e.g. "11"); None for universal builds.
    cxx_abi_version: Optional[str]
    # Torch "major.minor" the variant was built for; None for universal builds.
    torch_version: Optional[str]
    # Target OS (e.g. "linux"); None for universal builds.
    os: Optional[str]
    # CPU architecture (e.g. "x86_64"); None for universal builds.
    platform: Optional[str]

    @property
    def is_universal(self) -> bool:
        # A universal (pure-Python) build has no platform constraint.
        return self.platform is None
def build_variant_to_wheel(
    repo_id: str,
    *,
    version: str,
    variant_path: Path,
    wheel_dir: Path,
    manylinux_version: str = "2.28",
    python_version: str = "3.9",
) -> Path:
    """
    Create a wheel file from the variant path.

    Extracts the variant's build metadata and hands it to `build_wheel`.

    Args:
        repo_id: the kernel repo ID; the part after the slash becomes the
            package name (underscores normalized to dashes).
        version: the kernel version to encode in the wheel.
        variant_path: directory containing the built variant.
        wheel_dir: directory the wheel is written to.
        manylinux_version: manylinux tag for Linux wheels.
        python_version: minimum Python version for the wheel tags.

    Returns:
        Path of the generated wheel.
    """
    package_name = repo_id.split("/")[-1].replace("_", "-")
    metadata = extract_metadata(package_name, version, variant_path)
    return build_wheel(
        metadata,
        variant_path=variant_path,
        wheel_dir=wheel_dir,
        manylinux_version=manylinux_version,
        python_version=python_version,
    )
def extract_metadata(name: str, version: str, variant_path: Path) -> Metadata:
    """
    Extract metadata from the variant path.

    The directory name encodes the build:
    ``torch<MAJ><MIN>-cxx<ABI>-cu<CUDA>-<arch>-<os>`` or ``torch-universal``.

    Raises:
        ValueError: for non-Torch variants or malformed variant names.
    """
    variant_name = variant_path.name
    if variant_name == "torch-universal":
        # Universal (pure-Python) kernels carry no build constraints.
        return Metadata(
            name=name,
            version=version,
            cuda_version=None,
            cxx_abi_version=None,
            torch_version=None,
            os=None,
            platform=None,
        )
    if not variant_name.startswith("torch"):
        raise ValueError("Currently only conversion of Torch kernels is supported.")
    parts = variant_name.removeprefix("torch").split("-")
    # NOTE(review): Metal variant names appear to have four components
    # (e.g. torch27-metal-<arch>-darwin) and would be rejected here —
    # confirm whether Metal builds are meant to be convertible.
    if len(parts) != 5:
        raise ValueError(f"Invalid variant name: {variant_path.name}")
    torch_part, abi_part, cuda_part, arch_part, os_part = parts
    return Metadata(
        name=name,
        version=version,
        cuda_version=cuda_part.removeprefix("cu"),
        cxx_abi_version=abi_part.removeprefix("cxx"),
        # "26" -> "2.6": the last digit is the minor version.
        torch_version=f"{torch_part[:-1]}.{torch_part[-1:]}",
        os=os_part,
        platform=arch_part.replace("-", "_"),
    )
def build_wheel(
    metadata: Metadata,
    *,
    variant_path: Path,
    wheel_dir: Path,
    manylinux_version: str = "2.28",
    python_version: str = "3.9",
) -> Path:
    """
    Build the wheel file.

    Args:
        metadata: variant build metadata (see `Metadata`).
        variant_path: directory containing the built variant.
        wheel_dir: directory the wheel is written to.
        manylinux_version: manylinux tag version used for Linux platform tags.
        python_version: minimum Python version encoded in the wheel tags.

    Returns:
        Path of the generated `.whl` file.

    Raises:
        ImportError: when the `wheel` package is not installed.
        ValueError: when a non-universal variant is missing build information.
    """
    try:
        from wheel.wheelfile import WheelFile  # type: ignore
    except ImportError:
        raise ImportError(
            "The 'wheel' package is required to build wheels. Please install it with: `pip install wheel`"
        )

    name = metadata.name.replace("-", "_")
    python_version_flat = python_version.replace(".", "")
    if metadata.is_universal:
        # Pure-Python build: one wheel works for any ABI and platform.
        python_tag = f"py{python_version_flat}"
        abi_tag = "none"
        platform_tag = "any"
        wheel_filename = (
            f"{name}-{metadata.version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
        )
        dist_info_dir_name = f"{name}-{metadata.version}.dist-info"
        root_is_purelib = "true"
        requires_dist_torch = "torch"
    else:
        # Binary build: encode the Torch/CUDA/C++-ABI combination in a local
        # version segment and pin the Torch requirement to that version.
        python_tag = f"cp{python_version_flat}"
        abi_tag = "abi3"
        if (
            metadata.torch_version is None
            or metadata.cuda_version is None
            or metadata.cxx_abi_version is None
            or metadata.os is None
            or metadata.platform is None
        ):
            raise ValueError(
                "Torch version, CUDA version, C++ ABI version, OS, and platform must be specified for non-universal wheels."
            )
        local_version = f"torch{metadata.torch_version.replace('.', '')}cu{metadata.cuda_version}cxx{metadata.cxx_abi_version}"
        if metadata.os == "linux":
            platform_tag = (
                f"manylinux_{manylinux_version.replace('.', '_')}_{metadata.platform}"
            )
        else:
            platform_tag = f"{metadata.os}_{metadata.platform.replace('-', '_')}"
        wheel_filename = f"{name}-{metadata.version}+{local_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
        dist_info_dir_name = f"{name}-{metadata.version}+{local_version}.dist-info"
        root_is_purelib = "false"
        requires_dist_torch = f"torch=={metadata.torch_version}.*"

    wheel_path = wheel_dir / wheel_filename

    # WHEEL file contents (wheel format metadata).
    wheel_msg = Message(email.policy.compat32)
    wheel_msg.add_header("Wheel-Version", "1.0")
    wheel_msg.add_header("Generator", f"kernels ({KERNELS_VERSION})")
    wheel_msg.add_header("Root-Is-Purelib", root_is_purelib)
    wheel_msg.add_header("Tag", f"{python_tag}-{abi_tag}-{platform_tag}")

    # METADATA file contents (core package metadata).
    metadata_msg = Message(email.policy.compat32)
    metadata_msg.add_header("Metadata-Version", "2.1")
    metadata_msg.add_header("Name", name)
    metadata_msg.add_header("Version", metadata.version)
    metadata_msg.add_header("Summary", f"{name} kernel")
    metadata_msg.add_header("Requires-Python", ">=3.9")
    metadata_msg.add_header("Requires-Dist", requires_dist_torch)

    source_pkg_dir = variant_path / name
    with WheelFile(wheel_path, "w") as wheel_file:
        for root, dirnames, filenames in os.walk(source_pkg_dir):
            for filename in filenames:
                # Skip compiled bytecode; it is not portable across builds.
                if filename.endswith(".pyc"):
                    continue
                abs_filepath = os.path.join(root, filename)
                # Archive entries are relative to the variant root.
                entry_name = os.path.relpath(abs_filepath, variant_path)
                wheel_file.write(abs_filepath, entry_name)
        wheel_metadata_path = os.path.join(dist_info_dir_name, "WHEEL")
        wheel_file.writestr(wheel_metadata_path, str(wheel_msg).encode("utf-8"))
        metadata_path = os.path.join(dist_info_dir_name, "METADATA")
        wheel_file.writestr(metadata_path, str(metadata_msg).encode("utf-8"))
    return wheel_path

10
tests/conftest.py Normal file
View File

@ -0,0 +1,10 @@
import sys
import pytest
def pytest_runtest_setup(item):
    """Skip OS-gated tests when running on a different platform."""
    # Marker name -> (required sys.platform prefix, skip message).
    os_markers = {
        "linux_only": ("linux", "skipping Linux-only test on non-Linux platform"),
        "darwin_only": ("darwin", "skipping macOS-only test on non-macOS platform"),
    }
    for marker, (platform_prefix, reason) in os_markers.items():
        if marker in item.keywords and not sys.platform.startswith(platform_prefix):
            pytest.skip(reason)

View File

@ -1,54 +1,82 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
"hash_type": "git_lfs_concat"
}
}

View File

@ -9,6 +9,11 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@pytest.fixture
def universal_kernel():
return get_kernel("kernels-community/triton-scaled-mm")
@ -21,6 +26,7 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_fast(kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
@ -36,6 +42,15 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
x = torch.arange(-10, 10, dtype=dtype, device="mps")
y = metal_kernel.relu(x)
assert torch.allclose(y, torch.relu(x))
@pytest.mark.linux_only
@pytest.mark.parametrize(
"kernel_exists",
[
@ -52,6 +67,7 @@ def test_has_kernel(kernel_exists):
assert has_kernel(repo_id, revision=revision) == kernel
@pytest.mark.linux_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")

View File

@ -16,18 +16,21 @@ def device():
return "cuda"
@pytest.mark.linux_only
def test_gelu_small(kernel, device, benchmark):
x = torch.randn(32, 32, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_medium(kernel, device, benchmark):
x = torch.randn(128, 128, dtype=torch.float16, device=device)
y = torch.empty_like(x)
benchmark(kernel.gelu_fast, y, x)
@pytest.mark.linux_only
def test_gelu_large(kernel, device, benchmark):
x = torch.randn(512, 512, dtype=torch.float16, device=device)
y = torch.empty_like(x)

View File

@ -1,6 +1,8 @@
from dataclasses import dataclass
from pathlib import Path
import pytest
from kernels import load_kernel
from kernels.cli import download_kernels
@ -17,6 +19,7 @@ def test_download_all_hash_validation():
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
@pytest.mark.linux_only
def test_load_locked():
project_dir = Path(__file__).parent / "kernel_locking"
# Also validates that hashing works correctly.

View File

@ -1,3 +1,5 @@
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
@ -6,6 +8,7 @@ from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
@ -16,14 +19,18 @@ kernel_layer_mapping = {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
"SiluAndMulNoCompile": {
"cuda": LayerRepository(
repo_id="kernels-test/op-without-fake-test",
layer_name="SiluAndMul",
)
},
"SiluAndMulStringDevice": {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
}
@ -43,6 +50,11 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
@use_kernel_forward_from_hub("SiluAndMulNoCompile")
class SiluAndMulNoCompileKernel(SiluAndMul):
pass
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMulWithKernel(SiluAndMul):
pass
@ -71,6 +83,7 @@ def test_arg_kinds():
assert arg_kind("foo", "bar", kwarg1="baz", kwarg2=5) == ("foo", "bar", "baz", 5)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
@ -80,7 +93,7 @@ def test_hub_forward(cls, device):
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel = kernelize(cls(), device=device)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
@ -98,11 +111,70 @@ def test_layer_fallback_works():
pass
# Check that we don't raise an exception for a non-existing kernel.
SiluAndMulWithKernelFallback()
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda")
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
ctx = (
pytest.raises(ValueError, match="does not fulfill requirements")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
with ctx:
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
needs_torch_compile=True,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
needs_torch_compile=True,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {"SiluAndMul", "SiluAndMulStringDevice"}
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
extra_mapping1 = {
"TestKernel": {
@ -118,6 +190,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
@ -135,6 +208,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
@ -145,6 +219,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
@ -164,6 +239,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"TestKernel",
}
assert (
@ -174,6 +250,7 @@ def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
}
@ -205,6 +282,7 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
@pytest.mark.linux_only
def test_fallback_used_when_training():
@use_kernel_forward_from_hub("Linear")
class TorchLinear(nn.Linear):
@ -219,25 +297,8 @@ def test_fallback_used_when_training():
linear = TorchLinear(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
@ -249,6 +310,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@ -257,6 +319,31 @@ def test_fallback_used_when_training():
linear(X)
assert linear.n_calls == 0
# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 0
# Case 3: kernel out backward support should use the kernel in
# eval mode and the fallback in training. Test train ->
# eval -> train.
with use_kernel_mapping(
{
"Linear": {
@ -268,10 +355,43 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
# When switching the kernel to eval, forward gets replaced by
# the kernel.
linear.eval()
linear(X)
assert linear.n_calls == 1
## Let's do it in the other direction to make sure it works as well.
linear.train()
linear(X)
assert linear.n_calls == 2
# Case 4: same as case 3, but test eval -> train -> eval.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
):
linear.eval()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 2
linear.train()
linear(X)
assert linear.n_calls == 3
linear.eval()
linear(X)
assert linear.n_calls == 3