Compare commits

...

16 Commits

Author SHA1 Message Date
b7b5f40143 Set version to 0.7.0 2025-07-07 13:09:01 +00:00
b87e6fadbe Set version to 0.7.0.dev0 (#104) 2025-07-07 14:56:43 +02:00
fc935d9874 Support registering inference/training-specific layers (#103)
* Support registering inference/training-specific layers

This change makes it possible to register kernels specialized for
inference, training, and/or `torch.compile`. To do so, the mapping
notation is extended to support registering specialized kernels
for a specific 'mode'. For instance, the following mapping,

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
          Mode.DEFAULT: LayerRepository(
              repo_id="kernels-community/activation",
              layer_name="SiluAndMul",
          ),
          Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
              repo_id="kernels-community/activation-training-optimized",
              layer_name="SiluAndMul",
          ),
      }
    }
}
```

uses `kernels-community/activation` by default, but will switch to
using `kernels-community/activation-training-optimized` if a model
is kernelized for training and `torch.compile`.

To make it easier to add more modes in the future and to unify the
`register_kernel_mapping` and `kernelize` signatures, the `training`
and `needs_torch_compile` arguments of `kernelize` are replaced by
a single `mode` argument:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```

* Documentation fixes

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Add note on when the fallback is used

* Tighten up some Mode checks

* Fix ruff check

* Attempt to fix mypy errors

* More typing fixes

* Ignore Python < 3.11 type check SNAFU

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-07-04 19:57:14 +02:00
3622e1f8dd Add get_local_kernel function (#102)
This function loads a kernel from a local repository (e.g. the output
of kernel-builder), which can be handy for testing.
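A hedged usage sketch, where `activation` stands in for a kernel-builder
output directory containing `build/<variant>/`:

```python
from pathlib import Path

from kernels import get_local_kernel

# The build variant is resolved from the current Torch/CUDA environment.
module = get_local_kernel(Path("activation"), package_name="activation")
```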
2025-07-01 13:58:47 +02:00
a7f3b2e8ed Set version to 0.6.2.dev0 (#100) 2025-06-25 09:48:09 +02:00
a6ab5d83ba Make the flake work on Darwin (#98) 2025-06-24 20:35:21 +02:00
4f9f1abfb9 darwin: fix variant CPU for aarch64 (#97) 2025-06-24 20:35:07 +02:00
f94b7780a6 CI: main triton-layer-norm has docs, branch is gone (#99) 2025-06-24 16:40:36 +02:00
bd28883775 Set version to 0.6.1.dev1 (#96) 2025-06-20 11:43:26 +02:00
498429e322 Add README generation for layers (#94) 2025-06-20 10:16:50 +02:00
09c991af4b Add macOS requirements (#95) 2025-06-16 17:20:47 +02:00
bcf8df5875 Bump version to 0.6.0.dev0 (#93) 2025-06-04 13:59:32 +02:00
239afff6f5 Update Nix flake dependencies (#92)
* Update Nix flake dependencies

To ensure that we can test with Torch 2.7 kernels in the development
environment.

* Update nix fmt to use nixfmt-tree
2025-06-04 12:13:19 +02:00
c5ec6b900a Hotfix: add FAQ (#91) 2025-06-04 09:52:39 +02:00
3a635eaeea Automatic fallback for kernels that don't support training (#90)
For kernels that do not support backward, fall back to the original
implementation if `model.train(True)` is called. This removes the
need for the `needs_backward` argument of `kernelize`.
2025-06-03 19:13:57 +02:00
32ec496c5a Make the forward pass torch.compile compatible (#87)
* first commit

* style

* update

* fix

* different approach

* Polish kernelize

- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.

* Fixup tests

* Fix `torch.compile` support

* Remove some unused code

* Sync the docs

* CI: update Torch versions

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-06-03 15:06:02 +02:00
15 changed files with 991 additions and 257 deletions

View File

@ -24,7 +24,7 @@ jobs:
max-parallel: 4
matrix:
python-version: ["3.10", "3.12"]
torch-version: ["2.5.1", "2.6.0"]
torch-version: ["2.6.0", "2.7.0"]
env:
UV_PYTHON_PREFERENCE: only-managed
@ -63,7 +63,7 @@ jobs:
- name: Check README generation
# For now, just checks that generation doesn't fail.
run: |
uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
uv run kernels generate-readme kernels-community/triton-layer-norm
- name: Import check without torch
run: |

View File

@ -61,4 +61,5 @@ the Hub.
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)

13
docs/faq.md Normal file
View File

@ -0,0 +1,13 @@
# FAQ
## Why is the kernelization step needed?
In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.
To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
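
To illustrate, a hedged sketch (`MyModel` stands in for any model with
extensible layers): all dispatch decisions are made once, at kernelize time,
so the resulting `forward` methods are branch-free.

```python
import torch

from kernels import Mode, kernelize

model = MyModel(...)

# Device, training/inference, and torch.compile support are all
# resolved here, ahead of time.
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)

# The kernelized forward methods contain no data-dependent branches,
# so torch.compile can trace them.
model = torch.compile(model)
```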

View File

@ -37,8 +37,14 @@ to resolve the version constraints.
## Native Python module
Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
### Linux
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
@ -50,12 +56,18 @@ requirements:
- CXXABI 1.3.11
- GCC 7.0.0
These requirements can be checked with the ABI checker (see below).
These requirements can be checked with the ABI checker (see below).
- No dynamic library dependencies outside:
### macOS
- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).
The ABI3 requirement can be checked with the ABI checker (see below).
### ABI checker
The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):

View File

@ -23,33 +23,93 @@ class SiluAndMul(nn.Module):
return F.silu(input[..., :d]) * input[..., d:]
```
The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.
### External layers
An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by by monkeypatching it using the `replace_kernel_forward_from_hub` function.
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:
```python
from somelibrary import SiluAndMul
from kernels import register_kernel_mapping, replace_kernel_forward_from_hub
replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```
The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.
**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, add `TORCH_COMPILE` to the mode using the set union
(`|`) operator:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```
### Fallback `forward`
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
```python
kernel_layer_mapping = {
@ -57,11 +117,14 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
```
You can register such a mapping using `register_kernel_mapping`:
```python
register_kernel_mapping(kernel_layer_mapping)
```
@ -72,8 +135,63 @@ used with the `use_kernel_mapping` context manager:
```python
with use_kernel_mapping(kernel_layer_mapping):
# Use the layer for which the mapping is applied.
...
model = kernelize(model)
```
This ensures that the mapping is no longer active outside the
`with` scope.
### Registering kernels for specific modes
You might want to register multiple kernels for a particular layer, each
optimized for a specific mode. You can do so by registering layer
repositories for specific modes. For example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
Kernels are matched exactly on the mode. For instance, in the example above,
no kernel layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. If you want to
register a kernel to be used when the mode does not match any of the
modes in the mapping, you can use the special `Mode.DEFAULT` mode. For
example:
```python
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
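
As a hedged sketch (`MyModel` stands in for a model containing a
`SiluAndMul` layer), different `kernelize` calls resolve against this
mapping as follows:

```python
model = MyModel(...)

# Exact match: uses kernels-community/activation-inference-optimized.
model = kernelize(model, mode=Mode.INFERENCE)

# Exact match: uses kernels-community/activation-training-optimized.
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)

# No exact match, so Mode.DEFAULT applies:
# uses kernels-community/activation.
model = kernelize(model, mode=Mode.TRAINING)
```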

55
flake.lock generated
View File

@ -51,18 +51,38 @@
"type": "github"
}
},
"hf-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1750775451,
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
"owner": "huggingface",
"repo": "hf-nix",
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
"type": "github"
},
"original": {
"owner": "huggingface",
"repo": "hf-nix",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1737453259,
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
"lastModified": 1747820358,
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
"owner": "danieldk",
"repo": "nixpkgs",
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
"rev": "d3c1681180717528068082103bf323147de6ab0b",
"type": "github"
},
"original": {
"owner": "danieldk",
"ref": "outlines-v0.1.4-tgi",
"ref": "cudatoolkit-12.9-kernel-builder",
"repo": "nixpkgs",
"type": "github"
}
@ -70,11 +90,11 @@
"root": {
"inputs": {
"flake-utils": "flake-utils",
"hf-nix": "hf-nix",
"nixpkgs": [
"tgi-nix",
"hf-nix",
"nixpkgs"
],
"tgi-nix": "tgi-nix"
]
}
},
"systems": {
@ -106,27 +126,6 @@
"repo": "default",
"type": "github"
}
},
"tgi-nix": {
"inputs": {
"flake-compat": "flake-compat",
"flake-utils": "flake-utils_2",
"nixpkgs": "nixpkgs"
},
"locked": {
"lastModified": 1741617161,
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
"owner": "huggingface",
"repo": "text-generation-inference-nix",
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
"type": "github"
},
"original": {
"owner": "huggingface",
"ref": "kernels-0.2.0",
"repo": "text-generation-inference-nix",
"type": "github"
}
}
},
"root": "root",

View File

@ -1,7 +1,7 @@
{
inputs = {
tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
nixpkgs.follows = "tgi-nix/nixpkgs";
hf-nix.url = "github:huggingface/hf-nix";
nixpkgs.follows = "hf-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
};
outputs =
@ -9,21 +9,21 @@
self,
nixpkgs,
flake-utils,
tgi-nix,
hf-nix,
}:
flake-utils.lib.eachDefaultSystem (
system:
let
pkgs = import nixpkgs {
inherit system;
inherit (tgi-nix.lib) config;
config = hf-nix.lib.config system;
overlays = [
tgi-nix.overlays.default
hf-nix.overlays.default
];
};
in
{
formatter = pkgs.nixfmt-rfc-style;
formatter = pkgs.nixfmt-tree;
devShells = with pkgs; rec {
default = mkShell {
buildInputs =

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.5.0"
version = "0.7.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -24,7 +24,7 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mypy == 1.14.1",
"mypy >= 1.15.0",
"pytest >=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",

View File

@ -1,6 +1,8 @@
from kernels.layer import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
@ -8,6 +10,7 @@ from kernels.layer import (
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
@ -15,15 +18,18 @@ from kernels.utils import (
)
__all__ = [
"Device",
"LayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
"load_kernel",
"install_kernel",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
"kernelize",
"load_kernel",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"LayerRepository",
"Device",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
]

View File

@ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile(
)
def _extract_description_before_tags(docstring_mdx: str) -> str:
"""Extract the description part of a docstring before any tags."""
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
if positions:
first_tag_pos = min(positions)
return docstring_mdx[:first_tag_pos].strip()
else:
return docstring_mdx.strip()
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
"""Print the parameters section from a docstring."""
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
header = "#" * header_level
print(f"\n{header} Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def _print_returns_section(
docstring_mdx: str, *, context_name: str, header_level: int
) -> None:
"""Print the returns section from a docstring."""
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
if return_matches or returntype_matches:
header = "#" * header_level
print(f"\n{header} Returns")
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {context_name}"
)
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
def _get_docstring(obj, use_dict_check: bool = False) -> str:
"""Get docstring from an object, with fallback to default message."""
# Check whether the class/method itself has docs and not just
# the superclass.
if use_dict_check:
has_doc = obj.__dict__.get("__doc__", None) is not None
else:
has_doc = getattr(obj, "__doc__", None) is not None
# We use inspect.getdoc because it does normalization.
doc = inspect.getdoc(obj)
return doc if has_doc and doc is not None else "No documentation available."
def _process_and_print_docstring(
docstring: str, *, kernel_name: str, context_name: str, header_level: int
) -> None:
"""Convert docstring to MDX and print description, parameters, and returns sections."""
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
)
# Print the description
description = _extract_description_before_tags(docstring_mdx)
print(f"\n{description}")
# Print parameters and returns sections
_print_parameters_section(docstring_mdx, header_level=header_level)
_print_returns_section(
docstring_mdx, context_name=context_name, header_level=header_level
)
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
kernel_name = repo_id.split("/")[-1].replace("-", "_")
@ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
generate_metadata(kernel_module)
generate_kernel_doc(kernel_module, kernel_name)
generate_function_doc(kernel_module, kernel_name)
generate_layers_doc(kernel_module, kernel_name)
def generate_metadata(module: ModuleType):
def generate_metadata(module: ModuleType) -> None:
metadata = getattr(module, "__kernel_metadata__", {})
if "tags" not in metadata:
metadata["tags"] = ["kernel"]
@ -39,7 +121,7 @@ def generate_metadata(module: ModuleType):
print("---")
def generate_kernel_doc(module: ModuleType, kernel_name: str):
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
docstring = module.__doc__.strip() if module.__doc__ is not None else None
if docstring:
title, rest = docstring.split("\n", 1)
@ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str):
)
def generate_function_doc(kernel_module, kernel_name):
functions_info = []
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ == kernel_module.__name__:
# Exclude private functions.
if not name.startswith("_"):
try:
sig = inspect.signature(func)
docstring = inspect.getdoc(func) or "No documentation available."
functions_info.append((name, sig, docstring))
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
print("\n## Functions")
if not functions_info:
print(
"\nNo public top-level functions.",
)
return
# Track if we found any functions
found_functions = False
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
# Do not include imported functions.
if func.__module__ != kernel_module.__name__:
continue
# Exclude private functions.
if name.startswith("_"):
continue
found_functions = True
try:
sig = inspect.signature(func)
docstring = _get_docstring(func)
except ValueError:
print(
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
file=sys.stderr,
)
continue
for name, sig, docstring in functions_info:
print(f"\n### Function `{name}`")
print(f"\n`{sig}`")
docstring_mdx = convert_rst_docstring_to_mdx(
docstring, page_info={"package_name": kernel_name}
_process_and_print_docstring(
docstring, kernel_name=kernel_name, context_name=name, header_level=3
)
params_pos = docstring_mdx.find("<parameters>")
returns_pos = docstring_mdx.find("<returns>")
returntype_pos = docstring_mdx.find("<returntype>")
positions = [
pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
]
if not found_functions:
print("\nNo public top-level functions.")
if positions:
first_tag_pos = min(positions)
# The function description is anything before the first tag.
print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
else:
print(f"\n{docstring_mdx.strip()}")
# Extract parameters
matches = _RE_PARAMETERS.findall(docstring_mdx)
if matches:
print("\n### Parameters")
for match in matches:
print(f"\n{match[0].strip()}")
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
# Check if layers module is available
layers_module = getattr(kernel_module, "layers", None)
if layers_module is None:
return
# Extract return information
return_matches = _RE_RETURNS.findall(docstring_mdx)
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
print("\n## Layers")
if return_matches or returntype_matches:
print("\n### Returns", file=sys.stdout)
# Track if we found any classes
found_classes = False
if returntype_matches:
if len(returntype_matches) > 1:
raise ValueError(
f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
)
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
# Exclude classes that were imported.
if cls.__module__ != layers_module.__name__:
continue
found_classes = True
try:
# Get docstring, but not from superclasses.
class_docstring = _get_docstring(cls, use_dict_check=True)
except Exception:
print(
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
file=sys.stderr,
)
continue
print(f"\n### Class `{class_name}`")
# Always print class description (helper handles conversion and formatting)
class_docstring_mdx = convert_rst_docstring_to_mdx(
class_docstring, page_info={"package_name": kernel_name}
)
description = _extract_description_before_tags(class_docstring_mdx)
print(f"\n{description}")
# Document methods
print("\n#### Methods")
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
# Note: also skip __init__, since extension layers cannot have a constructor.
if method_name.startswith("_"):
continue
# Skip methods from superclasses.
if method_name not in cls.__dict__:
continue
try:
sig = inspect.signature(method)
method_docstring = _get_docstring(method)
except ValueError:
print(
f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
file=sys.stderr,
)
continue
if return_matches:
for match in return_matches:
print(f"\n{match[0].strip()}")
print(f"\n##### Method `{method_name}`")
print(f"\n`{sig}`")
_process_and_print_docstring(
method_docstring,
kernel_name=kernel_name,
context_name=method_name,
header_level=6,
)
if not found_classes:
print("\nNo layers defined.")

View File

@ -1,19 +1,67 @@
from __future__ import annotations
import inspect
import os
import warnings
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, Union
from enum import Flag, auto
from types import MethodType
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Tuple,
Type,
Union,
)
from .utils import get_kernel
if TYPE_CHECKING:
import torch
from torch import nn
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
class Mode(Flag):
"""
Kernelize mode
The `Mode` flag is used by `kernelize` to select kernels for the given
mode. Mappings can be registered for specific modes.
* `INFERENCE`: The kernel is used for inference.
* `TRAINING`: The kernel is used for training.
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
matches.
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
should be used for layers that are used for inference *with* `torch.compile`.
"""
_NONE = 0
DEFAULT = auto()
TRAINING = auto()
INFERENCE = auto()
TORCH_COMPILE = auto()
def __or__(self, other: Mode) -> Mode:
union = super().__or__(other)
if Mode.INFERENCE in union and Mode.TRAINING in union:
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
if Mode.DEFAULT in union and union != Mode.DEFAULT:
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
return union
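# Hedged examples of the invariants enforced by `__or__` above:
#
#   Mode.TRAINING | Mode.TORCH_COMPILE  # valid combination
#   Mode.INFERENCE | Mode.TRAINING      # raises ValueError
#   Mode.DEFAULT | Mode.TORCH_COMPILE   # raises ValueError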
@dataclass(frozen=True)
class Device:
type: str
@ -53,13 +101,19 @@ class LayerRepository:
return hash((self.layer_name, self.repo_id, self.revision))
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
ContextVar("_KERNEL_MAPPING", default={})
)
def use_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
*,
inherit_mapping: bool = True,
):
@ -87,12 +141,17 @@ def use_kernel_mapping(
def register_kernel_mapping(
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
):
"""
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
Exemple usage:
Allows one to register a mapping between a layer name and the corresponding
kernel(s) to use, depending on the device. This should be used in conjunction
with `kernelize`.
Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
@ -113,34 +172,106 @@ def register_kernel_mapping(
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
device = (
Device(type=new_device) if isinstance(new_device, str) else new_device
)
if isinstance(new_repo, LayerRepository):
kernel_options = {Mode.DEFAULT: new_repo}
else:
device_repo[new_device] = new_repo
kernel_options = new_repo
device_repo[device] = kernel_options
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
def replace_kernel_forward_from_hub(
cls,
layer_name: str,
):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This function monkeypatches a layer, replacing the `forward` method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
Args:
cls: The layer class to decorate
layer_name: The name of the layer to use for kernel lookup
"""
cls.kernel_layer_name = layer_name
fallback_forward = cls.forward
cached_layer: Dict[LayerRepository, nn.Module] = {}
def _select_repository(
repositories: Dict[Mode, LayerRepository],
*,
mode: Mode,
) -> Optional[Tuple[LayerRepository, Mode]]:
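# Prefer an exact match on the requested mode, fall back to the DEFAULT
# entry when one is registered, and return None otherwise.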
if mode in repositories:
return (repositories[mode], mode)
elif Mode.DEFAULT in repositories:
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
else:
return None
def kernelize(
model: "nn.Module",
*,
mode: Mode,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):
"""
Iterate over all modules in the model and replace the `forward` method of
extensible layers for which kernels are registered using `register_kernel_mapping`
or `use_kernel_mapping`.
Args:
model: The PyTorch model to kernelize
mode: the mode that the kernel is going to be used in (e.g.
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
and `torch.compile`).
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
Returns:
The kernelized model
"""
import torch
if mode == Mode.DEFAULT:
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
# Type check ignored because this causes a false negative on Python < 3.11.
# Looks similar to: https://github.com/python/mypy/issues/9642
# Remove once we start doing typing checks on >= 3.11.
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
module_class = type(module)
if not hasattr(module_class, "kernel_layer_name"):
continue
layer_name = module_class.kernel_layer_name
def forward(self, x, *args, **kwargs):
if _DISABLE_KERNEL_MAPPING:
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
needs_backward = self.training
is_compiling = _is_torchdynamo_compiling()
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
kernel = _KERNEL_MAPPING.get().get(layer_name)
if kernel is None:
warnings.warn(
"\n"
@ -150,82 +281,70 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
)
if not use_fallback:
raise ValueError(f"No layer mapping for `{layer_name}`")
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
device = getattr(x, "device", None)
if device is None:
return fallback_forward(self, x, *args, **kwargs)
# Get kernel options for the device
repos = kernel.get(device_type)
repo = kernel.get(Device(type=device.type))
if repo is None:
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
return fallback_forward(self, x, *args, **kwargs)
_replace_forward(module, module_class)
continue
# Short-circuit if we already loaded the layer.
layer = cached_layer.get(repo, None)
if layer is not None:
# Switch to fallback when the layer does not support:
# compilation/compile when needed.
# backward when needed
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
needs_fallback |= is_compiling and not getattr(
layer, "can_torch_compile", False
)
if needs_fallback:
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
repo_with_mode = _select_repository(
repos,
mode=mode,
)
# We have to validate against the original signature.
orig_forward = cls.forward
try:
cls.forward = fallback_forward
_validate_layer(check_cls=cls, cls=layer)
finally:
cls.forward = orig_forward
if repo_with_mode is None:
if not use_fallback:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
continue
cached_layer[repo] = layer
repo, repo_mode = repo_with_mode
# Switch to fallback when the layer does not support
# compilation/compile when needed.
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
needs_fallback |= is_compiling and not getattr(
layer, "can_torch_compile", False
layer = _get_layer_memoize(repo, module_class)
# Ideally we would do validation on the mapping where we check that
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
# the actual layer is compatible with that. Unfortunately, this would
# mean that we have to pre-download everything.
_validate_layer_has_mode(
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
)
if needs_fallback:
return fallback_forward(self, x, *args, **kwargs)
return layer.forward(self, x, *args, **kwargs)
_conditionally_replace_forward(
module=module,
layer=layer,
mode=mode,
use_fallback=use_fallback,
)
cls.forward = forward
return model
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
def use_kernel_forward_from_hub(layer_name: str):
"""
Replace the forward function of a layer using a layer from the kernel hub.
This decorator can be applied to a layer and replaces the forward method
of the layer with that of a layer from the hub. The replacement is done
when a layer matching `layer_name` and device type is registered through
`register_layer_mapping`. The device type is inferred from the first
argument to `forward`.
Make a layer extensible using the name `layer_name`.
"""
def decorator(cls):
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
replace_kernel_forward_from_hub(cls, layer_name)
return cls
return decorator
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
def _get_kernel_layer(
*, repo_id: str, layer_name: str, revision: str
) -> Type["nn.Module"]:
"""Get a layer from a kernel."""
kernel = get_kernel(repo_id, revision=revision)
@ -242,13 +361,13 @@ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Mo
def _validate_layer(*, check_cls, cls):
import torch.nn as nn
# The layer must at least have the following properties: (1) it
# must be stateless; (2) the forward signature should correspond to
# the signature it is replacing; (3) forward should not call other
# methods.
from torch import nn
if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
@ -281,17 +400,90 @@ def _validate_layer(*, check_cls, cls):
)
def _is_torchdynamo_compiling():
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
# hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
def _find_device(model: "nn.Module") -> Device:
try:
import torch
param = next(model.parameters())
except StopIteration:
raise ValueError(
"Cannot determine model device, provide as `device` argument to `kernelize`."
)
return torch.compiler.is_compiling()
except Exception:
try:
import torch._dynamo as dynamo # noqa: F401
return Device(type=param.device.type)
return dynamo.is_compiling()
except Exception:
return False
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback if the mode is not supported by the layer.
# Note that this is useful even after _validate_layer_has_mode because
# layers registered with the DEFAULT mode never get rejected by
# _validate_layer_has_mode. For such layers, we want to fall back in
# case the layer does not support the given mode.
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
layer, "can_torch_compile", False
)
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
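# Bind the kernel layer's forward function to this module instance, so that
# the replacement affects the instance rather than its class.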
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
def _validate_layer_has_mode(
*,
layer_name: str,
module: Type["nn.Module"],
repo: LayerRepository,
repo_mode: Mode,
):
"""
Check that a repository supports the mode that it was registered for.
"""
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
if Mode.TORCH_COMPILE in repo_mode and not getattr(
module, "can_torch_compile", False
):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
return True
def _get_layer_memoize(
repo: LayerRepository, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
return layer
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
return layer

View File

@ -55,6 +55,7 @@ def build_variant() -> str:
os = platform.system().lower()
if os == "darwin":
cpu = "aarch64" if cpu == "arm64" else cpu
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
@ -109,6 +110,23 @@ def install_kernel(
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Redo with more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
@ -127,7 +145,7 @@ def install_kernel(
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
f"Kernel at path `{repo_path}` does not have build: {variant}"
)
return package_name, variant_path
@ -165,10 +183,24 @@ def install_kernel_all_variants(
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Download and import a kernel from the Hugging Face Hub.
The kernel is downloaded from the repository `repo_id` at
branch/commit/tag `revision`.
"""
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
"""
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment

View File

@ -1,54 +1,82 @@
[
{
"repo_id": "kernels-community/activation",
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
"variants": {
"torch25-cxx11-cu118-x86_64-linux": {
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu121-x86_64-linux": {
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
"hash_type": "git_lfs_concat"
},
"torch25-cxx11-cu124-x86_64-linux": {
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu118-x86_64-linux": {
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu121-x86_64-linux": {
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
"hash_type": "git_lfs_concat"
},
"torch25-cxx98-cu124-x86_64-linux": {
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu124-x86_64-linux": {
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-aarch64-linux": {
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
"hash_type": "git_lfs_concat"
},
"torch26-cxx11-cu126-x86_64-linux": {
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu118-x86_64-linux": {
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu124-x86_64-linux": {
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-aarch64-linux": {
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
"hash_type": "git_lfs_concat"
},
"torch26-cxx98-cu126-x86_64-linux": {
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu118-x86_64-linux": {
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-aarch64-linux": {
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu126-x86_64-linux": {
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-aarch64-linux": {
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
"hash_type": "git_lfs_concat"
},
"torch27-cxx11-cu128-x86_64-linux": {
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
"hash_type": "git_lfs_concat"
}
}

View File

@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel, has_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
@pytest.fixture
@ -9,6 +9,14 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@ -42,6 +50,22 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):

View File

@ -1,3 +1,5 @@
from contextlib import nullcontext
import pytest
import torch
import torch.nn as nn
@ -6,6 +8,8 @@ from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
@ -16,7 +20,6 @@ kernel_layer_mapping = {
Device(type="cuda"): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
"SiluAndMulNoCompile": {
@ -29,7 +32,6 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
},
}
@ -64,6 +66,18 @@ class SiluAndMulStringDevice(SiluAndMul):
pass
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
@ -92,7 +106,7 @@ def test_hub_forward(cls, device):
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
@ -110,13 +124,14 @@ def test_layer_fallback_works():
pass
# Check that we don't raise an exception for a non-existing kernel.
SiluAndMulWithKernelFallback()
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_torch_compile_layer(cls, device):
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
@ -124,7 +139,43 @@ def test_torch_compile_layer(cls, device):
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel)
ctx = (
pytest.raises(ValueError, match="does not support mode")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
with ctx:
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
@pytest.mark.parametrize("device", ["cuda"])
def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), dtype=torch.float32, device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = cls()
silu_and_mul_with_kernel.eval()
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
Y_compiled = silu_and_mul_compiled(X)
@ -174,7 +225,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -185,7 +238,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -194,7 +249,9 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -205,7 +262,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -244,40 +303,152 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
@pytest.mark.linux_only
def test_fallback_used_when_training():
@use_kernel_forward_from_hub("Linear")
class TorchLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
linear = TorchLinear(32, 32).to("cuda")
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: a kernel registered without further specification becomes
# the base kernel.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
layer_name="LinearBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
linear.eval()
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Training has a kernel, so no fallback.
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2
# Case 4: register a kernel for a combined mode (training + torch.compile).
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
@ -289,6 +460,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@ -297,21 +469,40 @@ def test_fallback_used_when_training():
linear(X)
assert linear.n_calls == 0
# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
assert linear.n_calls == 0
linear.eval()
linear(X)
assert linear.n_calls == 1
assert linear.n_calls == 0
def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)