Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-21 13:33:48 +08:00)

Compare commits: compile_co...v0.7.0 (16 commits)

| SHA1 |
|---|
| b7b5f40143 |
| b87e6fadbe |
| fc935d9874 |
| 3622e1f8dd |
| a7f3b2e8ed |
| a6ab5d83ba |
| 4f9f1abfb9 |
| f94b7780a6 |
| bd28883775 |
| 498429e322 |
| 09c991af4b |
| bcf8df5875 |
| 239afff6f5 |
| c5ec6b900a |
| 3a635eaeea |
| 32ec496c5a |
.github/workflows/test.yml (vendored, 4 changed lines)

@@ -24,7 +24,7 @@ jobs:
      max-parallel: 4
      matrix:
        python-version: ["3.10", "3.12"]
        torch-version: ["2.5.1", "2.6.0"]
        torch-version: ["2.6.0", "2.7.0"]

    env:
      UV_PYTHON_PREFERENCE: only-managed

@@ -63,7 +63,7 @@ jobs:
      - name: Check README generation
        # For now, just checks that generation doesn't fail.
        run: |
          uv run kernels generate-readme kernels-community/triton-layer-norm --revision docs
          uv run kernels generate-readme kernels-community/triton-layer-norm

      - name: Import check without torch
        run: |
@@ -61,4 +61,5 @@ the Hub.
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
- [Frequently Asked Questions](docs/faq.md)
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
docs/faq.md (new file, 13 lines added)

@@ -0,0 +1,13 @@
# FAQ

## Why is the kernelization step needed?

In earlier versions of `kernels`, a layer's `forward` was replaced by
`use_kernel_forward_from_hub` and `replace_kernel_forward_from_hub`. The
new `forward` would dispatch to a kernel based on the device type,
whether a model was training, etc. However, this approach was
fundamentally incompatible with `torch.compile` since it relied
on data-dependent branching.

To avoid branching, we have to make dispatch decisions ahead of time,
which is what the `kernelize` function does.
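A minimal sketch of what this looks like in practice (the layer definition is the `SiluAndMul` example from `docs/layers.md`; whether a Hub kernel is actually picked up still depends on the registered mappings): the device and mode decision is made once when `kernelize` runs, so the compiled `forward` contains no data-dependent branching.

```python
import torch
from torch import nn
from kernels import Mode, kernelize, use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return nn.functional.silu(input[..., :d]) * input[..., d:]

model = nn.Sequential(SiluAndMul())

# Dispatch is resolved here, once, for the given mode and device ...
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, device="cuda")

# ... so the compiled graph contains no data-dependent kernel selection.
compiled_model = torch.compile(model)
```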
@@ -37,8 +37,14 @@ to resolve the version constraints.
## Native Python module

Kernels will typically contain a native Python module with precompiled
compute kernels and bindings. This module must fulfill the following
requirements:
compute kernels and bindings. This module must fulfill the requirements
outlined in this section. For all operating systems, a kernel must not
have dynamic library dependencies outside:

- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.

### Linux

- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.

@@ -50,12 +56,18 @@ requirements:
- CXXABI 1.3.11
- GCC 7.0.0

These requirements can be checked with the ABI checker (see below).
These requirements can be checked with the ABI checker (see below).

- No dynamic library dependencies outside:
### macOS

- Torch;
- CUDA/ROCm libraries installed as dependencies of Torch.
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
  for compatibility with Python 3.9 and later.
- macOS deployment target 15.0.
- Metal 3.0 (`-std=metal3.0`).

The ABI3 requirement can be checked with the ABI checker (see below).

### ABI checker

The manylinux_2_28 and Python ABI 3.9 version requirements can be checked with
[`kernel-abi-check`](https://crates.io/crates/kernel-abi-check):
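The dynamic-library restriction above can also be sanity-checked locally. The snippet below is purely illustrative and is not part of `kernel-abi-check`; the allow-list of library-name prefixes is an assumption. It runs `ldd` on a built shared object and reports anything that does not look like Torch, CUDA/ROCm, or a core system library.

```python
import subprocess
import sys

# Illustrative allow-list; the authoritative check is `kernel-abi-check`.
ALLOWED_PREFIXES = (
    "linux-vdso", "/lib64/ld-linux", "libtorch", "libc10", "libcuda",
    "libcudart", "libamdhip", "libc.", "libm.", "libdl.", "libpthread.",
    "librt.", "libgcc_s.", "libstdc++.",
)

def check_dynamic_deps(so_path: str) -> int:
    out = subprocess.run(["ldd", so_path], capture_output=True, text=True, check=True)
    unexpected = []
    for line in out.stdout.splitlines():
        name = line.strip().split()[0] if line.strip() else ""
        if name and not name.startswith(ALLOWED_PREFIXES):
            unexpected.append(name)
    for name in unexpected:
        print(f"unexpected dependency: {name}", file=sys.stderr)
    return 1 if unexpected else 0

if __name__ == "__main__":
    sys.exit(check_dynamic_deps(sys.argv[1]))
```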
docs/layers.md (142 changed lines)

@@ -23,33 +23,93 @@ class SiluAndMul(nn.Module):
        return F.silu(input[..., :d]) * input[..., d:]
```

The decorator changes the layer, so that other implementations of the `forward`
method can be registered using the name `SiluAndMul`.
The decorator does not change the behavior of the class -- it annotates
the class with the given name (here `SiluAndMul`). The `kernelize` function
described below uses this name to look up kernels for the layer.

### External layers

An existing layer that does not (yet) have the `use_kernel_forward_from_hub`
decorator can be made extensible by monkeypatching it using the `replace_kernel_forward_from_hub` function.
decorator can be made extensible using the `replace_kernel_forward_from_hub`
function:

```python
from somelibrary import SiluAndMul

replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
register_kernel_mapping(kernel_layer_mapping)
```

The `register_kernel_mapping` call maps the name `SiluAndMul` to actual
hub kernels. See the [Registering a hub kernel for a layer](#registering-a-hub-kernel-for-a-layer)
section for more information.

**Warning:** we strongly recommend using layers with a decorator, since
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.
## Kernelizing a model

A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```

The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```

**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
### Kernel device

Kernels can be registered per device type. For instance, separate `cuda` and
`metal` kernels could be registered for the name `SiluAndMul`. By default,
`kernelize` will try to infer the device type from the model's parameters.
You can pass the device type to `kernelize` if the device type cannot be
inferred (e.g. because the model has no parameters):

```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```
### `torch.compile`

Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, you need to add this to the mode. You can use the
set union (`|`) operator to add `TORCH_COMPILE` to the mode:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
```
### Fallback `forward`

If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized, layer. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:

```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
```

This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer

Once a layer is made extensible, users can register hub kernels for it
by name using the `register_kernel_mapping` function. For example:
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:

```python
kernel_layer_mapping = {
@@ -57,11 +117,14 @@ kernel_layer_mapping = {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
            revision="layers",
        )
    }
}
```

You can register such a mapping using `register_kernel_mapping`:

```python
register_kernel_mapping(kernel_layer_mapping)
```
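Putting the pieces above together, a minimal end-to-end sketch could look as follows. The repository and layer names are taken from the examples above; the `nn.Sequential` wrapper merely stands in for a real model.

```python
import torch
from torch import nn
from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return nn.functional.silu(input[..., :d]) * input[..., d:]

# Map the layer name to a kernel on the Hub (names from the example above).
register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            )
        }
    }
)

model = nn.Sequential(SiluAndMul())
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
```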
@@ -72,8 +135,63 @@ used with the `use_kernel_mapping` context manager:

```python
with use_kernel_mapping(kernel_layer_mapping):
    # Use the layer for which the mapping is applied.
    ...
    model = kernelize(model)
```

This ensures that the mapping is not active anymore outside the
`with`-scope.
### Registering kernels for specific modes

You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```
The kernels will match exactly on the mode. For instance, in the example
above no kernel layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
register a kernel to be used when the mode does not match any of the
modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
so. For example:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.DEFAULT: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```

In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
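The selection rule described above (exact match on the registered mode, then `Mode.DEFAULT`, otherwise the fallback `forward`) can be summarized with a small helper. This is an illustration of the documented behaviour, not the library's internal code:

```python
from typing import Dict, Optional
from kernels import LayerRepository, Mode

def select_repository(
    repositories: Dict[Mode, LayerRepository], mode: Mode
) -> Optional[LayerRepository]:
    # An exact match on the registered mode wins ...
    if mode in repositories:
        return repositories[mode]
    # ... otherwise fall back to DEFAULT when it was registered ...
    if Mode.DEFAULT in repositories:
        return repositories[Mode.DEFAULT]
    # ... otherwise there is no kernel for this mode and `kernelize`
    # keeps the original forward (or raises when use_fallback=False).
    return None
```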
flake.lock (generated, 55 changed lines)
@ -51,18 +51,38 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"hf-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"flake-utils": "flake-utils_2",
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1750775451,
|
||||
"narHash": "sha256-HiGqtwzIgUH7Xkh+wgpvHRZGooqrW0z663E6nauczA4=",
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"rev": "5943c3169e861618a6634bc8dbdb498e413ab9b7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"repo": "hf-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1737453259,
|
||||
"narHash": "sha256-5LaFI9SQwCZmJDasMoYMdzNouWXNk3BvjKcO19tq1Rs=",
|
||||
"lastModified": 1747820358,
|
||||
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
||||
"owner": "danieldk",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e0372dbcfd19ddd783b7c3b3868f19322f83318e",
|
||||
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "danieldk",
|
||||
"ref": "outlines-v0.1.4-tgi",
|
||||
"ref": "cudatoolkit-12.9-kernel-builder",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
@ -70,11 +90,11 @@
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"hf-nix": "hf-nix",
|
||||
"nixpkgs": [
|
||||
"tgi-nix",
|
||||
"hf-nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"tgi-nix": "tgi-nix"
|
||||
]
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
@ -106,27 +126,6 @@
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"tgi-nix": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"flake-utils": "flake-utils_2",
|
||||
"nixpkgs": "nixpkgs"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1741617161,
|
||||
"narHash": "sha256-cwKYAsIVSLtoLbG48+oi3NkSrvuZRLYs8lkJmpDsTw0=",
|
||||
"owner": "huggingface",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"rev": "5946021ec6cb6aae18158a9dc27f893cfbab2925",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "huggingface",
|
||||
"ref": "kernels-0.2.0",
|
||||
"repo": "text-generation-inference-nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
flake.nix (12 changed lines)

@@ -1,7 +1,7 @@
{
  inputs = {
    tgi-nix.url = "github:huggingface/text-generation-inference-nix/kernels-0.2.0";
    nixpkgs.follows = "tgi-nix/nixpkgs";
    hf-nix.url = "github:huggingface/hf-nix";
    nixpkgs.follows = "hf-nix/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };
  outputs =

@@ -9,21 +9,21 @@
      self,
      nixpkgs,
      flake-utils,
      tgi-nix,
      hf-nix,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
          inherit (tgi-nix.lib) config;
          config = hf-nix.lib.config system;
          overlays = [
            tgi-nix.overlays.default
            hf-nix.overlays.default
          ];
        };
      in
      {
        formatter = pkgs.nixfmt-rfc-style;
        formatter = pkgs.nixfmt-tree;
        devShells = with pkgs; rec {
          default = mkShell {
            buildInputs =
@@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.5.0"
version = "0.7.0"
description = "Download compute kernels"
authors = [
    { name = "OlivierDehaene", email = "olivier@huggingface.co" },

@@ -24,7 +24,7 @@ build-backend = "setuptools.build_meta"

[dependency-groups]
dev = [
    "mypy == 1.14.1",
    "mypy >= 1.15.0",
    "pytest >=8",
    # Whatever version is compatible with pytest.
    "pytest-benchmark",
@@ -1,6 +1,8 @@
from kernels.layer import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    replace_kernel_forward_from_hub,
    use_kernel_forward_from_hub,
@@ -8,6 +10,7 @@ from kernels.layer import (
)
from kernels.utils import (
    get_kernel,
    get_local_kernel,
    get_locked_kernel,
    has_kernel,
    install_kernel,
@@ -15,15 +18,18 @@ from kernels.utils import (
)

__all__ = [
    "Device",
    "LayerRepository",
    "Mode",
    "get_kernel",
    "get_local_kernel",
    "get_locked_kernel",
    "has_kernel",
    "load_kernel",
    "install_kernel",
    "use_kernel_forward_from_hub",
    "use_kernel_mapping",
    "kernelize",
    "load_kernel",
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "LayerRepository",
    "Device",
    "use_kernel_forward_from_hub",
    "use_kernel_mapping",
]
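The re-exported public API can be exercised with a short, self-contained snippet. This is a sketch that assumes a CUDA device and that a build of `kernels-community/activation` exists for the current environment:

```python
import torch
from kernels import get_kernel, has_kernel

# Check that a build exists for the current environment before downloading.
if has_kernel("kernels-community/activation"):
    activation = get_kernel("kernels-community/activation")

    x = torch.randn(4, 8, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    # gelu_fast is one of the kernels exercised in the test suite below.
    activation.gelu_fast(y, x)
```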
@ -17,6 +17,87 @@ _RE_RETURNTYPE = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _extract_description_before_tags(docstring_mdx: str) -> str:
|
||||
"""Extract the description part of a docstring before any tags."""
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1]
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
return docstring_mdx[:first_tag_pos].strip()
|
||||
else:
|
||||
return docstring_mdx.strip()
|
||||
|
||||
|
||||
def _print_parameters_section(docstring_mdx: str, *, header_level: int) -> None:
|
||||
"""Print the parameters section from a docstring."""
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _print_returns_section(
|
||||
docstring_mdx: str, *, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Print the returns section from a docstring."""
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
header = "#" * header_level
|
||||
print(f"\n{header} Returns")
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {context_name}"
|
||||
)
|
||||
print(f"\n**Type**: {returntype_matches[0][0].strip()}")
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
|
||||
|
||||
def _get_docstring(obj, use_dict_check: bool = False) -> str:
|
||||
"""Get docstring from an object, with fallback to default message."""
|
||||
# Check whether the class/method itself has docs and not just
|
||||
# the superclass.
|
||||
if use_dict_check:
|
||||
has_doc = obj.__dict__.get("__doc__", None) is not None
|
||||
else:
|
||||
has_doc = getattr(obj, "__doc__", None) is not None
|
||||
|
||||
# We use inspect.getdoc because it does normalization.
|
||||
doc = inspect.getdoc(obj)
|
||||
|
||||
return doc if has_doc and doc is not None else "No documentation available."
|
||||
|
||||
|
||||
def _process_and_print_docstring(
|
||||
docstring: str, *, kernel_name: str, context_name: str, header_level: int
|
||||
) -> None:
|
||||
"""Convert docstring to MDX and print description, parameters, and returns sections."""
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
|
||||
# Print the description
|
||||
description = _extract_description_before_tags(docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Print parameters and returns sections
|
||||
_print_parameters_section(docstring_mdx, header_level=header_level)
|
||||
_print_returns_section(
|
||||
docstring_mdx, context_name=context_name, header_level=header_level
|
||||
)
|
||||
|
||||
|
||||
def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
kernel_module = get_kernel(repo_id=repo_id, revision=revision)
|
||||
kernel_name = repo_id.split("/")[-1].replace("-", "_")
|
||||
@ -24,9 +105,10 @@ def generate_readme_for_kernel(repo_id: str, *, revision: str = "main") -> None:
|
||||
generate_metadata(kernel_module)
|
||||
generate_kernel_doc(kernel_module, kernel_name)
|
||||
generate_function_doc(kernel_module, kernel_name)
|
||||
generate_layers_doc(kernel_module, kernel_name)
|
||||
|
||||
|
||||
def generate_metadata(module: ModuleType):
|
||||
def generate_metadata(module: ModuleType) -> None:
|
||||
metadata = getattr(module, "__kernel_metadata__", {})
|
||||
if "tags" not in metadata:
|
||||
metadata["tags"] = ["kernel"]
|
||||
@ -39,7 +121,7 @@ def generate_metadata(module: ModuleType):
|
||||
print("---")
|
||||
|
||||
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
def generate_kernel_doc(module: ModuleType, kernel_name: str) -> None:
|
||||
docstring = module.__doc__.strip() if module.__doc__ is not None else None
|
||||
if docstring:
|
||||
title, rest = docstring.split("\n", 1)
|
||||
@ -49,76 +131,112 @@ def generate_kernel_doc(module: ModuleType, kernel_name: str):
|
||||
)
|
||||
|
||||
|
||||
def generate_function_doc(kernel_module, kernel_name):
|
||||
functions_info = []
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ == kernel_module.__name__:
|
||||
# Exclude private functions.
|
||||
if not name.startswith("_"):
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = inspect.getdoc(func) or "No documentation available."
|
||||
functions_info.append((name, sig, docstring))
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def generate_function_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
print("\n## Functions")
|
||||
|
||||
if not functions_info:
|
||||
print(
|
||||
"\nNo public top-level functions.",
|
||||
)
|
||||
return
|
||||
# Track if we found any functions
|
||||
found_functions = False
|
||||
|
||||
for name, func in inspect.getmembers(kernel_module, inspect.isfunction):
|
||||
# Do not include imported functions.
|
||||
if func.__module__ != kernel_module.__name__:
|
||||
continue
|
||||
|
||||
# Exclude private functions.
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
found_functions = True
|
||||
|
||||
try:
|
||||
sig = inspect.signature(func)
|
||||
docstring = _get_docstring(func)
|
||||
except ValueError:
|
||||
print(
|
||||
f"Warning: Could not retrieve signature for {name} in {kernel_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
for name, sig, docstring in functions_info:
|
||||
print(f"\n### Function `{name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
docstring, page_info={"package_name": kernel_name}
|
||||
_process_and_print_docstring(
|
||||
docstring, kernel_name=kernel_name, context_name=name, header_level=3
|
||||
)
|
||||
|
||||
params_pos = docstring_mdx.find("<parameters>")
|
||||
returns_pos = docstring_mdx.find("<returns>")
|
||||
returntype_pos = docstring_mdx.find("<returntype>")
|
||||
positions = [
|
||||
pos for pos in [params_pos, returns_pos, returntype_pos] if pos != -1
|
||||
]
|
||||
if not found_functions:
|
||||
print("\nNo public top-level functions.")
|
||||
|
||||
if positions:
|
||||
first_tag_pos = min(positions)
|
||||
# The function description is anything before the first tag.
|
||||
print(f"\n{docstring_mdx[:first_tag_pos].strip()}")
|
||||
else:
|
||||
print(f"\n{docstring_mdx.strip()}")
|
||||
|
||||
# Extract parameters
|
||||
matches = _RE_PARAMETERS.findall(docstring_mdx)
|
||||
if matches:
|
||||
print("\n### Parameters")
|
||||
for match in matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
def generate_layers_doc(kernel_module: ModuleType, kernel_name: str) -> None:
|
||||
# Check if layers module is available
|
||||
layers_module = getattr(kernel_module, "layers", None)
|
||||
if layers_module is None:
|
||||
return
|
||||
|
||||
# Extract return information
|
||||
return_matches = _RE_RETURNS.findall(docstring_mdx)
|
||||
returntype_matches = _RE_RETURNTYPE.findall(docstring_mdx)
|
||||
print("\n## Layers")
|
||||
|
||||
if return_matches or returntype_matches:
|
||||
print("\n### Returns", file=sys.stdout)
|
||||
# Track if we found any classes
|
||||
found_classes = False
|
||||
|
||||
if returntype_matches:
|
||||
if len(returntype_matches) > 1:
|
||||
raise ValueError(
|
||||
f"More than one <returntype> tag found in docstring for {name} in {kernel_module.__name__}"
|
||||
)
|
||||
for class_name, cls in inspect.getmembers(layers_module, inspect.isclass):
|
||||
# Exclude classes that were imported.
|
||||
if cls.__module__ != layers_module.__name__:
|
||||
continue
|
||||
|
||||
found_classes = True
|
||||
|
||||
try:
|
||||
# Get docstring, but not from superclasses.
|
||||
class_docstring = _get_docstring(cls, use_dict_check=True)
|
||||
except Exception:
|
||||
print(
|
||||
f"Warning: Could not retrieve documentation for class {class_name} in {layers_module.__name__}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
print(f"\n### Class `{class_name}`")
|
||||
|
||||
# Always print class description (helper handles conversion and formatting)
|
||||
class_docstring_mdx = convert_rst_docstring_to_mdx(
|
||||
class_docstring, page_info={"package_name": kernel_name}
|
||||
)
|
||||
description = _extract_description_before_tags(class_docstring_mdx)
|
||||
print(f"\n{description}")
|
||||
|
||||
# Document methods
|
||||
print("\n#### Methods")
|
||||
|
||||
for method_name, method in inspect.getmembers(cls, inspect.isfunction):
|
||||
# Note: also skip __init__, since extension layers cannot have a constructor.
|
||||
if method_name.startswith("_"):
|
||||
continue
|
||||
|
||||
# Skip methods from superclasses.
|
||||
if method_name not in cls.__dict__:
|
||||
continue
|
||||
|
||||
try:
|
||||
sig = inspect.signature(method)
|
||||
method_docstring = _get_docstring(method)
|
||||
except ValueError:
|
||||
print(
|
||||
f"\n**Type**: {returntype_matches[0][0].strip()}", file=sys.stdout
|
||||
f"Warning: Could not retrieve signature for {method_name} in {class_name}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
if return_matches:
|
||||
for match in return_matches:
|
||||
print(f"\n{match[0].strip()}")
|
||||
print(f"\n##### Method `{method_name}`")
|
||||
print(f"\n`{sig}`")
|
||||
|
||||
_process_and_print_docstring(
|
||||
method_docstring,
|
||||
kernel_name=kernel_name,
|
||||
context_name=method_name,
|
||||
header_level=6,
|
||||
)
|
||||
|
||||
if not found_classes:
|
||||
print("\nNo layers defined.")
|
||||
|
@ -1,19 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
from enum import Flag, auto
|
||||
from types import MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from .utils import get_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
|
||||
|
||||
|
||||
class Mode(Flag):
|
||||
"""
|
||||
Kernelize mode
|
||||
|
||||
The `Mode` flag is used by `kernelize` to select kernels for the given
|
||||
mode. Mappings can be registered for specific modes.
|
||||
|
||||
* `INFERENCE`: The kernel is used for inference.
|
||||
* `TRAINING`: The kernel is used for training.
|
||||
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
|
||||
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
|
||||
matches.
|
||||
|
||||
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
|
||||
should be used for layers that are used for inference *with* `torch.compile`.
|
||||
"""
|
||||
|
||||
_NONE = 0
|
||||
DEFAULT = auto()
|
||||
TRAINING = auto()
|
||||
INFERENCE = auto()
|
||||
TORCH_COMPILE = auto()
|
||||
|
||||
def __or__(self, other: Mode) -> Mode:
|
||||
union = super().__or__(other)
|
||||
|
||||
if Mode.INFERENCE in union and Mode.TRAINING in union:
|
||||
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
|
||||
|
||||
if Mode.DEFAULT in union and union != Mode.DEFAULT:
|
||||
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
|
||||
|
||||
return union
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Device:
|
||||
type: str
|
||||
@ -53,13 +101,19 @@ class LayerRepository:
|
||||
return hash((self.layer_name, self.repo_id, self.revision))
|
||||
|
||||
|
||||
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
|
||||
"_KERNEL_MAPPING", default={}
|
||||
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
|
||||
|
||||
|
||||
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
|
||||
ContextVar("_KERNEL_MAPPING", default={})
|
||||
)
|
||||
|
||||
|
||||
def use_kernel_mapping(
|
||||
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
|
||||
mapping: Dict[
|
||||
str,
|
||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
||||
],
|
||||
*,
|
||||
inherit_mapping: bool = True,
|
||||
):
|
||||
@ -87,12 +141,17 @@ def use_kernel_mapping(
|
||||
|
||||
|
||||
def register_kernel_mapping(
|
||||
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]]
|
||||
mapping: Dict[
|
||||
str,
|
||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
||||
],
|
||||
):
|
||||
"""
|
||||
Allows one to register a mapping between a layer name the corresponding kernel to use, depending on the device.
|
||||
This should be use in conjunction with `use_kernel_hub_forward` decorator on the classname.
|
||||
Exemple usage:
|
||||
Allows one to register a mapping between a layer name and the corresponding
|
||||
kernel(s) to use, depending on the device. This should be used in conjunction
|
||||
with `kernelize`.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
from kernels import LayerRepository, register_kernel_mapping
|
||||
@ -113,34 +172,106 @@ def register_kernel_mapping(
|
||||
for new_kernel, new_device_repos in mapping.items():
|
||||
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
|
||||
for new_device, new_repo in new_device_repos.items():
|
||||
if isinstance(new_device, str):
|
||||
device_repo[Device(type=new_device)] = new_repo
|
||||
device = (
|
||||
Device(type=new_device) if isinstance(new_device, str) else new_device
|
||||
)
|
||||
|
||||
if isinstance(new_repo, LayerRepository):
|
||||
kernel_options = {Mode.DEFAULT: new_repo}
|
||||
else:
|
||||
device_repo[new_device] = new_repo
|
||||
kernel_options = new_repo
|
||||
|
||||
device_repo[device] = kernel_options
|
||||
|
||||
|
||||
def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool = True):
|
||||
def replace_kernel_forward_from_hub(
|
||||
cls,
|
||||
layer_name: str,
|
||||
):
|
||||
"""
|
||||
Replace the forward function of a layer using a layer from the kernel hub.
|
||||
This function monkeypatches a layer, replacing the `forward` method
|
||||
of the layer with that of a layer from the hub. The replacement is done
|
||||
when a layer matching `layer_name` and device type is registered through
|
||||
`register_layer_mapping`. The device type is inferred from the first
|
||||
argument to `forward`.
|
||||
Decorator that prepares a layer class to use a kernel from the Hugging Face Hub.
|
||||
|
||||
This decorator stores the layer name and original forward method, which will be used
|
||||
by the kernelize function to replace the forward implementation with the appropriate
|
||||
kernel from the hub.
|
||||
|
||||
Args:
|
||||
cls: The layer class to decorate
|
||||
layer_name: The name of the layer to use for kernel lookup
|
||||
"""
|
||||
cls.kernel_layer_name = layer_name
|
||||
|
||||
fallback_forward = cls.forward
|
||||
|
||||
cached_layer: Dict[LayerRepository, nn.Module] = {}
|
||||
def _select_repository(
|
||||
repositories: Dict[Mode, LayerRepository],
|
||||
*,
|
||||
mode: Mode,
|
||||
) -> Optional[Tuple[LayerRepository, Mode]]:
|
||||
if mode in repositories:
|
||||
return (repositories[mode], mode)
|
||||
elif Mode.DEFAULT in repositories:
|
||||
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def kernelize(
|
||||
model: "nn.Module",
|
||||
*,
|
||||
mode: Mode,
|
||||
device: Optional[Union[str, "torch.device"]] = None,
|
||||
use_fallback: bool = True,
|
||||
):
|
||||
"""
|
||||
Iterate over all modules in the model and replace the `forward` method of
|
||||
extensible layers for which kernels are registered using `register_kernel_mapping`
|
||||
or `use_kernel_mapping`.
|
||||
|
||||
Args:
|
||||
model: The PyTorch model to kernelize
|
||||
mode: the mode that the kernel is going to be used in (e.g.
|
||||
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
|
||||
and `torch.compile`).
|
||||
device: The device type to load kernels for. The device type will be inferred
|
||||
from the parameters of the model when not provided.
|
||||
use_fallback: Whether to use the original forward method of modules when no
|
||||
compatible kernel could be found. If set to `False`, an exception will
|
||||
be raised in such cases.
|
||||
|
||||
Returns:
|
||||
The kernelized model
|
||||
"""
|
||||
import torch
|
||||
|
||||
if mode == Mode.DEFAULT:
|
||||
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
|
||||
|
||||
# Type check ignored because this causes a false negative on Python < 3.11.
|
||||
# Looks similar to: https://github.com/python/mypy/issues/9642
|
||||
# Remove once we start doing typing checks on >= 3.11.
|
||||
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
|
||||
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
|
||||
|
||||
if device is None:
|
||||
device_type = _find_device(model)
|
||||
elif isinstance(device, str):
|
||||
device_type = Device(type=torch.device(device).type)
|
||||
else:
|
||||
device_type = Device(device.type)
|
||||
assert isinstance(device_type, Device)
|
||||
|
||||
for _, module in model.named_modules():
|
||||
module_class = type(module)
|
||||
if not hasattr(module_class, "kernel_layer_name"):
|
||||
continue
|
||||
layer_name = module_class.kernel_layer_name
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
if _DISABLE_KERNEL_MAPPING:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
_replace_forward(module, module_class)
|
||||
continue
|
||||
|
||||
needs_backward = self.training
|
||||
is_compiling = _is_torchdynamo_compiling()
|
||||
kernel = _KERNEL_MAPPING.get().get(str(layer_name))
|
||||
|
||||
kernel = _KERNEL_MAPPING.get().get(layer_name)
|
||||
if kernel is None:
|
||||
warnings.warn(
|
||||
"\n"
|
||||
@ -150,82 +281,70 @@ def replace_kernel_forward_from_hub(cls, layer_name: str, *, use_fallback: bool
|
||||
)
|
||||
if not use_fallback:
|
||||
raise ValueError(f"No layer mapping for `{layer_name}`")
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
_replace_forward(module, module_class)
|
||||
continue
|
||||
|
||||
device = getattr(x, "device", None)
|
||||
if device is None:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
# Get kernel options for the device
|
||||
repos = kernel.get(device_type)
|
||||
|
||||
repo = kernel.get(Device(type=device.type))
|
||||
if repo is None:
|
||||
if repos is None:
|
||||
if not use_fallback:
|
||||
raise ValueError(
|
||||
f"No layer mapping for `{layer_name}` with device type `{device.type}`"
|
||||
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
|
||||
)
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
_replace_forward(module, module_class)
|
||||
continue
|
||||
|
||||
# Short-circuit if we already loaded the layer.
|
||||
layer = cached_layer.get(repo, None)
|
||||
if layer is not None:
|
||||
# Switch to fallback when the layer does not support:
|
||||
# compilation/compile when needed.
|
||||
# backward when needed
|
||||
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
|
||||
needs_fallback |= is_compiling and not getattr(
|
||||
layer, "can_torch_compile", False
|
||||
)
|
||||
if needs_fallback:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
layer_name=repo.layer_name,
|
||||
revision=repo.revision,
|
||||
repo_with_mode = _select_repository(
|
||||
repos,
|
||||
mode=mode,
|
||||
)
|
||||
|
||||
# We have to validate against the original signature.
|
||||
orig_forward = cls.forward
|
||||
try:
|
||||
cls.forward = fallback_forward
|
||||
_validate_layer(check_cls=cls, cls=layer)
|
||||
finally:
|
||||
cls.forward = orig_forward
|
||||
if repo_with_mode is None:
|
||||
if not use_fallback:
|
||||
raise ValueError(
|
||||
f"No repository for `{layer_name}` for configuration mode={mode}"
|
||||
)
|
||||
_replace_forward(module, module_class)
|
||||
continue
|
||||
|
||||
cached_layer[repo] = layer
|
||||
repo, repo_mode = repo_with_mode
|
||||
|
||||
# Switch to fallback when the layer does not support
|
||||
# compilation/compile when needed.
|
||||
needs_fallback = needs_backward and not getattr(layer, "has_backward", True)
|
||||
needs_fallback |= is_compiling and not getattr(
|
||||
layer, "can_torch_compile", False
|
||||
layer = _get_layer_memoize(repo, module_class)
|
||||
|
||||
# Ideally we would do validation on the mapping where we check that
|
||||
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
|
||||
# the actual layer is compatible with that. Unfortunately, this would
|
||||
# mean that we have to pre-download everything.
|
||||
_validate_layer_has_mode(
|
||||
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
|
||||
)
|
||||
if needs_fallback:
|
||||
return fallback_forward(self, x, *args, **kwargs)
|
||||
|
||||
return layer.forward(self, x, *args, **kwargs)
|
||||
_conditionally_replace_forward(
|
||||
module=module,
|
||||
layer=layer,
|
||||
mode=mode,
|
||||
use_fallback=use_fallback,
|
||||
)
|
||||
|
||||
cls.forward = forward
|
||||
return model
|
||||
|
||||
|
||||
def use_kernel_forward_from_hub(layer_name: str, *, use_fallback: bool = True):
|
||||
def use_kernel_forward_from_hub(layer_name: str):
|
||||
"""
|
||||
Replace the forward function of a layer using a layer from the kernel hub.
|
||||
This decorator can be applied to a layer and replaces the forward method
|
||||
of the layer with that of a layer from the hub. The replacement is done
|
||||
when a layer matching `layer_name` and device type is registered through
|
||||
`register_layer_mapping`. The device type is inferred from the first
|
||||
argument to `forward`.
|
||||
Make a layer extensible using the name `layer_name`.
|
||||
"""
|
||||
|
||||
def decorator(cls):
|
||||
replace_kernel_forward_from_hub(cls, layer_name, use_fallback=use_fallback)
|
||||
replace_kernel_forward_from_hub(cls, layer_name)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Module":
|
||||
def _get_kernel_layer(
|
||||
*, repo_id: str, layer_name: str, revision: str
|
||||
) -> Type["nn.Module"]:
|
||||
"""Get a layer from a kernel."""
|
||||
|
||||
kernel = get_kernel(repo_id, revision=revision)
|
||||
@ -242,13 +361,13 @@ def _get_kernel_layer(*, repo_id: str, layer_name: str, revision: str) -> "nn.Mo
|
||||
|
||||
|
||||
def _validate_layer(*, check_cls, cls):
|
||||
import torch.nn as nn
|
||||
|
||||
# The layer must have at least have the following properties: (1) it
|
||||
# must be stateless; (2) the forward signature should correspond to
|
||||
# the signature it is replacing; (3) forward should not call other
|
||||
# methods.
|
||||
|
||||
from torch import nn
|
||||
|
||||
if not issubclass(cls, nn.Module):
|
||||
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
|
||||
|
||||
@ -281,17 +400,90 @@ def _validate_layer(*, check_cls, cls):
|
||||
)
|
||||
|
||||
|
||||
def _is_torchdynamo_compiling():
|
||||
# Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622)
|
||||
# hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3)
|
||||
def _find_device(model: "nn.Module") -> Device:
|
||||
try:
|
||||
import torch
|
||||
param = next(model.parameters())
|
||||
except StopIteration:
|
||||
raise ValueError(
|
||||
"Cannot determine model device, provide as `device` argument to `kernelize`."
|
||||
)
|
||||
|
||||
return torch.compiler.is_compiling()
|
||||
except Exception:
|
||||
try:
|
||||
import torch._dynamo as dynamo # noqa: F401
|
||||
return Device(type=param.device.type)
|
||||
|
||||
return dynamo.is_compiling()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _conditionally_replace_forward(
|
||||
*,
|
||||
module: "nn.Module",
|
||||
layer: Type["nn.Module"],
|
||||
mode: Mode,
|
||||
use_fallback: bool,
|
||||
):
|
||||
module_class = type(module)
|
||||
|
||||
# Switch to fallback if the mode is not supported by the layer.
|
||||
# Note that this is useful even after _validate_layer_has_mode because
|
||||
# layers registered with the DEFAULT mode never get rejected by
|
||||
# _validate_layer_has_mode. For such layers, we want to fall back in
|
||||
# case the layer does not support the given mode.
|
||||
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
|
||||
layer, "can_torch_compile", False
|
||||
)
|
||||
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
|
||||
|
||||
if needs_fallback:
|
||||
if use_fallback:
|
||||
_replace_forward(module, module_class)
|
||||
else:
|
||||
raise ValueError(f"Available kernel does not support mode: {mode}")
|
||||
else:
|
||||
_replace_forward(module, layer)
|
||||
|
||||
|
||||
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
|
||||
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _validate_layer_has_mode(
|
||||
*,
|
||||
layer_name: str,
|
||||
module: Type["nn.Module"],
|
||||
repo: LayerRepository,
|
||||
repo_mode: Mode,
|
||||
):
|
||||
"""
|
||||
Check that a repository supports the mode that it was registered for.
|
||||
"""
|
||||
|
||||
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
|
||||
raise ValueError(
|
||||
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
|
||||
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
||||
)
|
||||
|
||||
if Mode.TORCH_COMPILE in repo_mode and not getattr(
|
||||
module, "can_torch_compile", False
|
||||
):
|
||||
raise ValueError(
|
||||
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
|
||||
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _get_layer_memoize(
|
||||
repo: LayerRepository, module_class: Type["nn.Module"]
|
||||
) -> Type["nn.Module"]:
|
||||
layer = _CACHED_LAYER.get(repo, None)
|
||||
if layer is not None:
|
||||
return layer
|
||||
|
||||
layer = _get_kernel_layer(
|
||||
repo_id=repo.repo_id,
|
||||
layer_name=repo.layer_name,
|
||||
revision=repo.revision,
|
||||
)
|
||||
_validate_layer(check_cls=module_class, cls=layer)
|
||||
_CACHED_LAYER[repo] = layer
|
||||
|
||||
return layer
|
||||
|
@ -55,6 +55,7 @@ def build_variant() -> str:
|
||||
os = platform.system().lower()
|
||||
|
||||
if os == "darwin":
|
||||
cpu = "aarch64" if cpu == "arm64" else cpu
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}"
|
||||
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
@ -109,6 +110,23 @@ def install_kernel(
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
return _load_kernel_from_path(repo_path, package_name, variant_locks)
|
||||
except FileNotFoundError:
|
||||
# Redo with more specific error message.
|
||||
raise FileNotFoundError(
|
||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||
)
|
||||
|
||||
|
||||
def _load_kernel_from_path(
|
||||
repo_path: Path,
|
||||
package_name: str,
|
||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||
) -> Tuple[str, Path]:
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
||||
variant_path = repo_path / "build" / variant
|
||||
universal_variant_path = repo_path / "build" / universal_variant
|
||||
|
||||
@ -127,7 +145,7 @@ def install_kernel(
|
||||
|
||||
if not os.path.exists(module_init_path):
|
||||
raise FileNotFoundError(
|
||||
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||
f"Kernel at path `{repo_path}` does not have build: {variant}"
|
||||
)
|
||||
|
||||
return package_name, variant_path
|
||||
@ -165,10 +183,24 @@ def install_kernel_all_variants(
|
||||
|
||||
|
||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
"""
|
||||
Download and import a kernel from the Hugging Face Hub.
|
||||
|
||||
The kernel is downloaded from the repository `repo_id` at
|
||||
branch/commit/tag `revision`.
|
||||
"""
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||
"""
|
||||
Import a kernel from a local kernel repository path.
|
||||
"""
|
||||
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
|
@ -1,54 +1,82 @@
|
||||
[
|
||||
{
|
||||
"repo_id": "kernels-community/activation",
|
||||
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
|
||||
"sha": "fd6842e88f1f23f198551d78a4541b8eb07e0538",
|
||||
"variants": {
|
||||
"torch25-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
|
||||
"hash": "sha256-61e3e51b5b59b30d4a6ba943a5e6e4ef5a9c8260cc4bca40b9fb462c0777842b",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu121-x86_64-linux": {
|
||||
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
|
||||
"hash": "sha256-baa6b872040730bd1d676c011381f6f626fb96189837b828f587c806af8994fa",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
|
||||
"hash": "sha256-c1ec7457847fa1f0e4ab43234dfc3cd0959977e03dc2ffe89b4f6b90970c7965",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
|
||||
"hash": "sha256-412f9c841f20741e42f2c6cdb8c7da0e33ab436b219975acffe18b62b97ecd7c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu121-x86_64-linux": {
|
||||
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
|
||||
"hash": "sha256-2fde7f97859506e000c1072b3916c0a75bc8cee750a9853ea8b68199e7b57bcd",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch25-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
|
||||
"hash": "sha256-93309986f39a64a5630378108154866f0545178fa8dfef9b8f8ccfef9a78608e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
|
||||
"hash": "sha256-3284d3c64b76d92c1ee930bce8013aff307f16eefb16c2d5dea9f2ca70e71e1f",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu124-x86_64-linux": {
|
||||
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
|
||||
"hash": "sha256-36a8c93773c08ddf8ef624a8a6b2866be26d1861450dfe1ecac0bed59f9ffa47",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-f5afb734520f587717665659798ff738a69e5ae1e34d4bd95624edd18fb165cd",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
|
||||
"hash": "sha256-940841a7cb44f76c9a896d8b39f5bc0e0420f1c4c05ae9423da96778de4d1f2c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu118-x86_64-linux": {
|
||||
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
|
||||
"hash": "sha256-8e0f907830c3acc8c6bebfc162c744012ff6973e8110d7bf8ecd74b492418204",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu124-x86_64-linux": {
|
||||
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
|
||||
"hash": "sha256-0833414cbe658baec55b7ff63537cddccc973fe99e3c03008cced5e66e38b6c1",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-aarch64-linux": {
|
||||
"hash": "sha256-d94fa59a13a5b623b2071aadcd1e6c8477c4d557fd06ad144f15b46b1fc71aab",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch26-cxx98-cu126-x86_64-linux": {
|
||||
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
|
||||
"hash": "sha256-64784f5f2f9e232d0f2fd824fbc47eadde505e3c232f351bead5b04c429c65c2",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu118-x86_64-linux": {
|
||||
"hash": "sha256-bcba3765f061649bac0e5a9159bea8349ced4780e24a2330aa62ce0f8d3a9d78",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-aarch64-linux": {
|
||||
"hash": "sha256-e4625df5706af025c70bd824d952b928d9a2965eeaefda72fc47be0fae680c5e",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu126-x86_64-linux": {
|
||||
"hash": "sha256-7d7d3e655f34a7b03d5603d7c1ab723ef3efc823291762421a8b3a4aa51bd405",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-aarch64-linux": {
|
||||
"hash": "sha256-60e076194dcd55b32c5aca72f09816cba0fff52f340c8a063b17ff0577154d99",
|
||||
"hash_type": "git_lfs_concat"
|
||||
},
|
||||
"torch27-cxx11-cu128-x86_64-linux": {
|
||||
"hash": "sha256-f0a3802382efdcd78b40601187a9c416579a24ef2ed5a60d2296ef0951a89597",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
|
@@ -1,7 +1,7 @@
import pytest
import torch

from kernels import get_kernel, has_kernel
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel


@pytest.fixture
@@ -9,6 +9,14 @@ def kernel():
    return get_kernel("kernels-community/activation")


@pytest.fixture
def local_kernel():
    package_name, path = install_kernel("kernels-community/activation", "main")
    # Path is the build variant path (build/torch-<...>), so the grandparent
    # is the kernel repository path.
    return get_local_kernel(path.parent.parent, package_name)


@pytest.fixture
def metal_kernel():
    return get_kernel("kernels-test/relu-metal")

@@ -42,6 +50,22 @@ def test_gelu_fast(kernel, device):
    assert torch.allclose(y, expected)


@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
    y = torch.empty_like(x)

    local_kernel.gelu_fast(y, x)

    expected = torch.tensor(
        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
        device=device,
        dtype=torch.float16,
    )

    assert torch.allclose(y, expected)


@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):
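Outside the test suite, the same pattern as the `local_kernel` fixture can be used to load a kernel from an already-downloaded repository path, for example for offline use. A sketch, assuming a CUDA device and the `kernels-community/activation` repository used above:

```python
import torch
from kernels import get_local_kernel, install_kernel

# Download once (or point at an existing local clone of the kernel repo).
package_name, variant_path = install_kernel("kernels-community/activation", "main")

# As in the fixture above: the repository root is two levels above the
# build-variant directory (build/torch<...>).
activation = get_local_kernel(variant_path.parent.parent, package_name)

x = torch.randn(3, 3, dtype=torch.float16, device="cuda")
y = torch.empty_like(x)
activation.gelu_fast(y, x)
```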
@ -1,3 +1,5 @@
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -6,6 +8,8 @@ from torch.nn import functional as F
|
||||
from kernels import (
|
||||
Device,
|
||||
LayerRepository,
|
||||
Mode,
|
||||
kernelize,
|
||||
register_kernel_mapping,
|
||||
use_kernel_forward_from_hub,
|
||||
)
|
||||
@ -16,7 +20,6 @@ kernel_layer_mapping = {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
},
|
||||
"SiluAndMulNoCompile": {
|
||||
@ -29,7 +32,6 @@ kernel_layer_mapping = {
|
||||
"cuda": LayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
revision="layers",
|
||||
)
|
||||
},
|
||||
}
|
||||
@ -64,6 +66,18 @@ class SiluAndMulStringDevice(SiluAndMul):
|
||||
pass
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("Linear")
|
||||
class TorchLinearWithCounter(nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Used to check that we called hub kernel.
|
||||
self.n_calls = 0
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
self.n_calls += 1
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
def test_arg_kinds():
|
||||
@use_kernel_forward_from_hub("ArgKind")
|
||||
class ArgKind(nn.Module):
|
||||
@ -92,7 +106,7 @@ def test_hub_forward(cls, device):
|
||||
X = torch.randn((32, 64), device=device)
|
||||
Y = silu_and_mul(X)
|
||||
|
||||
silu_and_mul_with_kernel = cls()
|
||||
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
|
||||
Y_kernel = silu_and_mul_with_kernel(X)
|
||||
|
||||
torch.testing.assert_close(Y_kernel, Y)
|
||||
@ -110,13 +124,14 @@ def test_layer_fallback_works():
|
||||
pass
|
||||
|
||||
# Check that we don't raise an exception for a non-existing kernel.
|
||||
SiluAndMulWithKernelFallback()
|
||||
silu_and_mul = SiluAndMulWithKernelFallback()
|
||||
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||
@pytest.mark.parametrize("device", ["cuda", "cpu"])
|
||||
def test_torch_compile_layer(cls, device):
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
def test_torch_compile_layer_without_fallback(cls, device):
|
||||
silu_and_mul = SiluAndMul()
|
||||
|
||||
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
||||
@ -124,7 +139,43 @@ def test_torch_compile_layer(cls, device):
|
||||
|
||||
silu_and_mul_with_kernel = cls()
|
||||
silu_and_mul_with_kernel.eval()
|
||||
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel)
|
||||
|
||||
ctx = (
|
||||
pytest.raises(ValueError, match="does not support mode")
|
||||
if cls is SiluAndMulNoCompileKernel
|
||||
else nullcontext()
|
||||
)
|
||||
with ctx:
|
||||
silu_and_mul_with_kernel = kernelize(
|
||||
silu_and_mul_with_kernel,
|
||||
device=device,
|
||||
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||
use_fallback=False,
|
||||
)
|
||||
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
||||
|
||||
Y_compiled = silu_and_mul_compiled(X)
|
||||
|
||||
torch.testing.assert_close(Y_compiled, Y)
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulNoCompileKernel])
|
||||
@pytest.mark.parametrize("device", ["cuda"])
|
||||
def test_torch_compile_layer_with_fallback(cls, device):
|
||||
silu_and_mul = SiluAndMul()
|
||||
|
||||
X = torch.randn((32, 64), dtype=torch.float32, device=device)
|
||||
Y = silu_and_mul(X)
|
||||
|
||||
silu_and_mul_with_kernel = cls()
|
||||
silu_and_mul_with_kernel.eval()
|
||||
silu_and_mul_with_kernel = kernelize(
|
||||
silu_and_mul_with_kernel,
|
||||
device=device,
|
||||
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
|
||||
)
|
||||
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
|
||||
|
||||
Y_compiled = silu_and_mul_compiled(X)
|
||||
|
||||
@ -174,7 +225,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)

@ -185,7 +238,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)

@ -194,7 +249,9 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)

@ -205,7 +262,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)

@ -244,40 +303,152 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


@pytest.mark.linux_only
def test_fallback_used_when_training():
@use_kernel_forward_from_hub("Linear")
class TorchLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0

def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)

linear = TorchLinear(32, 32).to("cuda")
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")

with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)


def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")

# Case 1: layer without further specification, becomes the
# base layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearImplicitBackward",
layer_name="LinearBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0

linear.eval()
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0

kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0

# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1

kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Training has a kernel, so no fallback.
assert linear.n_calls == 1

kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2

# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2

kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 2

kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 2

# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 3

kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 4

kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 4


@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")

# Case 1: kernel with explicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
@ -289,6 +460,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@ -297,21 +469,40 @@ def test_fallback_used_when_training():
linear(X)
assert linear.n_calls == 0

# Case 2: kernel with implicit backward support should always
# use the kernel.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
layer_name="LinearImplicitBackward",
)
}
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
assert linear.n_calls == 0

linear.eval()
linear(X)
assert linear.n_calls == 1
assert linear.n_calls == 0


def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING

with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.DEFAULT | Mode.TORCH_COMPILE

with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)

with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
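Taken together, these tests show how the new `Mode`-based API is meant to be used: kernel mappings are registered per layer name, device, and mode, and `kernelize` resolves the dispatch ahead of time. Below is a minimal usage sketch, not part of the diff; it assumes that `kernelize`, `Mode`, `LayerRepository`, `use_kernel_mapping`, and `use_kernel_forward_from_hub` are importable from the top-level `kernels` package, and `MyLinear` is a hypothetical example layer. The repo id and layer name are the ones exercised by the tests above.

```python
import torch
import torch.nn as nn

# Assumed import path for the public API exercised by the tests above.
from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


# "Linear" is the name under which kernel mappings are looked up.
@use_kernel_forward_from_hub("Linear")
class MyLinear(nn.Linear):  # hypothetical example layer
    pass


model = MyLinear(32, 32).to("cuda")

# Register a hub kernel for this layer name, keyed by device and mode.
with use_kernel_mapping(
    {
        "Linear": {
            "cuda": {
                Mode.TRAINING: LayerRepository(
                    repo_id="kernels-test/backward-marker-test",
                    layer_name="LinearBackward",
                )
            }
        }
    }
):
    # Dispatch is resolved here, ahead of time. A TRAINING kernel is
    # registered, so this call wires it in; requesting a mode with no
    # matching kernel would fall back to the original forward.
    kernelize(model, mode=Mode.TRAINING)

y = model(torch.randn(10, 32, device="cuda"))
```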