Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-21 05:30:30 +08:00)

Compare commits: release-0. ... v0.7.0

4 commits:

- b7b5f40143
- b87e6fadbe
- fc935d9874
- 3622e1f8dd
@@ -53,7 +53,15 @@ Hub kernels are registered. Kernelize can be used as follows:

 ```python
 model = MyModel(...)
-model = kernelize(model)
+model = kernelize(model, mode=Mode.INFERENCE)
 ```

+The `mode` specifies that the model will be used in inference. Similarly,
+you can ask `kernelize` to prepare the model for training:
+
+```python
+model = MyModel(...)
+model = kernelize(model, mode=Mode.TRAINING)
+```
+
 **Note:** the `kernelize` function modifies the model in-place, the model
@@ -69,36 +77,37 @@ inferred (e.g. because the model has no parameters):

 ```python
 model = MyModel(...)
-model = kernelize(model, device="cuda")
+model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
 ```

 ### `torch.compile`

 Not all Hub kernels support `torch.compile`. If you want to compile a model
-after kernelizing it, pass the `needs_torch_compile` argument to ensure that
-only kernels that support `torch.compile` will be loaded:
+after kernelizing it, you need to add this to the mode. You can use the
+set union (`|`) operator to add `TORCH_COMPILE` to the mode:

 ```python
 model = MyModel(...)
-model = kernelize(model, needs_torch_compile=True)
+model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
 ```

-### Fallback forward
+### Fallback `forward`

-The `needs_torch_compile` argument will fall back to the layer's original
-`forward` if the registered kernels does not support `torch.compile`. You
+If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
+kernel does not support backward passes or `torch.compile` respectively,
+`kernelize` will fall back to the original, non-kernelized, layer. You
 can let `kernelize` raise an exception instead by using `use_fallback=False`:

 ```python
 model = MyModel(...)
-model = kernelize(model, needs_torch_compile=True, use_fallback=False)
+model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
 ```

 This can be useful if you want to guarantee that Hub kernels are used.

 ## Registering a hub kernel for a layer

-`kernelize`` relies on kernel mappings to find Hub kernels for layers.
+`kernelize` relies on kernel mappings to find Hub kernels for layers.
 Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
 the Hub. For example:
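Aside (not part of the diff): putting the documentation changes above together, the intended 0.7.0 workflow looks roughly like the following sketch. `nn.Sequential(nn.Linear(...))` is only a placeholder for a model that contains kernelized layers.

```python
import torch
import torch.nn as nn

from kernels import Mode, kernelize

model = nn.Sequential(nn.Linear(32, 32))  # placeholder for a model with kernelized layers

# Only kernels that support torch.compile are eligible; layers whose kernels
# do not support it keep their original forward.
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
compiled = torch.compile(model)

# With use_fallback=False, kernelize raises ValueError for a kernelized layer
# that has no compatible kernel, instead of silently falling back.
try:
    kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
except ValueError as e:
    print(f"no compatible kernel: {e}")
```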
@@ -108,7 +117,6 @@ kernel_layer_mapping = {
         "cuda": LayerRepository(
             repo_id="kernels-community/activation",
             layer_name="SiluAndMul",
-            revision="layers",
         )
     }
 }
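Aside (not part of the diff): a minimal sketch of activating such a mapping globally, following the `register_kernel_mapping` docstring shown further down in this compare:

```python
from kernels import LayerRepository, register_kernel_mapping

kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

# Register the mapping process-wide; subsequent kernelize() calls consult it.
register_kernel_mapping(kernel_layer_mapping)
```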
@@ -132,3 +140,58 @@ with use_kernel_mapping(kernel_layer_mapping):

 This ensures that the mapping is not active anymore outside the
 `with`-scope.
+
+### Registering kernels for specific modes
+
+You might want to register two different kernels for a particular layer,
+where one kernel is optimized for a specific mode. You can do so by
+registering layer repositories for specific modes. For example:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": {
+            Mode.INFERENCE: LayerRepository(
+                repo_id="kernels-community/activation-inference-optimized",
+                layer_name="SiluAndMul",
+            ),
+            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
+                repo_id="kernels-community/activation-training-optimized",
+                layer_name="SiluAndMul",
+            ),
+        }
+    }
+}
+```
+
+The kernels will match exactly on the mode. So, in the example above, no kernel
+layer is used when the `mode` passed to `kernelize` is
+`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
+register a kernel to be used when the mode does not match any of the
+modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
+so. For example:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": {
+            Mode.DEFAULT: LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="SiluAndMul",
+            ),
+            Mode.INFERENCE: LayerRepository(
+                repo_id="kernels-community/activation-inference-optimized",
+                layer_name="SiluAndMul",
+            ),
+            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
+                repo_id="kernels-community/activation-training-optimized",
+                layer_name="SiluAndMul",
+            ),
+        }
+    }
+}
+```
+
+In this case, modes other than `Mode.INFERENCE` and
+`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
+`kernels-community/activation`.
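Aside (not part of the diff): given the `Mode.DEFAULT` mapping above, a sketch of which repository each call would select (`model` is assumed to be an already-constructed module):

```python
from kernels import Mode, kernelize

# Exact match -> kernels-community/activation-inference-optimized.
model = kernelize(model, mode=Mode.INFERENCE)

# Exact match -> kernels-community/activation-training-optimized.
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)

# No exact match -> the Mode.DEFAULT entry, kernels-community/activation.
model = kernelize(model, mode=Mode.TRAINING)
```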
@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.6.2.dev0"
+version = "0.7.0"
 description = "Download compute kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
@@ -24,7 +24,7 @@ build-backend = "setuptools.build_meta"

 [dependency-groups]
 dev = [
-    "mypy == 1.14.1",
+    "mypy >= 1.15.0",
     "pytest >=8",
     # Whatever version is compatible with pytest.
    "pytest-benchmark",
@@ -1,6 +1,7 @@
 from kernels.layer import (
     Device,
     LayerRepository,
+    Mode,
     kernelize,
     register_kernel_mapping,
     replace_kernel_forward_from_hub,
@@ -9,6 +10,7 @@ from kernels.layer import (
 )
 from kernels.utils import (
     get_kernel,
+    get_local_kernel,
     get_locked_kernel,
     has_kernel,
     install_kernel,
@@ -16,16 +18,18 @@ from kernels.utils import (
 )

 __all__ = [
+    "Device",
+    "LayerRepository",
+    "Mode",
     "get_kernel",
+    "get_local_kernel",
     "get_locked_kernel",
     "has_kernel",
-    "load_kernel",
     "install_kernel",
-    "use_kernel_forward_from_hub",
-    "use_kernel_mapping",
+    "kernelize",
+    "load_kernel",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
-    "LayerRepository",
-    "Device",
-    "kernelize",
+    "use_kernel_forward_from_hub",
+    "use_kernel_mapping",
 ]
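Aside (not part of the diff): after this change, the new names are importable from the package root, e.g.:

```python
from kernels import Mode, get_local_kernel, kernelize
```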
@@ -1,11 +1,21 @@
 from __future__ import annotations

 import inspect
 import os
 import warnings
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
+from enum import Flag, auto
 from types import MethodType
-from typing import TYPE_CHECKING, Dict, Optional, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 from .utils import get_kernel
@@ -17,6 +27,41 @@ if TYPE_CHECKING:
 _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))


+class Mode(Flag):
+    """
+    Kernelize mode
+
+    The `Mode` flag is used by `kernelize` to select kernels for the given
+    mode. Mappings can be registered for specific modes.
+
+    * `INFERENCE`: The kernel is used for inference.
+    * `TRAINING`: The kernel is used for training.
+    * `TORCH_COMPILE`: The kernel is used with `torch.compile`.
+    * `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
+      matches.
+
+    Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
+    should be used for layers that are used for inference *with* `torch.compile`.
+    """
+
+    _NONE = 0
+    DEFAULT = auto()
+    TRAINING = auto()
+    INFERENCE = auto()
+    TORCH_COMPILE = auto()
+
+    def __or__(self, other: Mode) -> Mode:
+        union = super().__or__(other)
+
+        if Mode.INFERENCE in union and Mode.TRAINING in union:
+            raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
+
+        if Mode.DEFAULT in union and union != Mode.DEFAULT:
+            raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
+
+        return union
+
+
 @dataclass(frozen=True)
 class Device:
     type: str
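Aside (not part of the diff): a quick sketch of the `Flag` semantics defined above, including the unions that the `__or__` override rejects:

```python
mode = Mode.INFERENCE | Mode.TORCH_COMPILE  # valid combination
assert Mode.TORCH_COMPILE in mode

# Both of the following raise ValueError in Mode.__or__:
#   Mode.INFERENCE | Mode.TRAINING     ("mutually exclusive")
#   Mode.DEFAULT | Mode.TORCH_COMPILE  ("cannot be combined with other modes")
```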
@@ -59,13 +104,16 @@ class LayerRepository:
 _CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}


-_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
-    "_KERNEL_MAPPING", default={}
+_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
+    ContextVar("_KERNEL_MAPPING", default={})
 )


 def use_kernel_mapping(
-    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
+    mapping: Dict[
+        str,
+        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+    ],
     *,
     inherit_mapping: bool = True,
 ):
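Aside (not part of the diff): based on this signature, per-mode entries can now be used in the scoped form as well. A hedged sketch (`mapping` and `model` are placeholders; the exact semantics of `inherit_mapping=False` are not shown in this hunk, but the flag presumably controls whether the active mappings are extended or replaced inside the scope):

```python
with use_kernel_mapping(mapping, inherit_mapping=False):
    # Presumably only `mapping` is visible to kernelize() in this scope.
    model = kernelize(model, mode=Mode.INFERENCE)
# Outside the with-scope, the previous mappings apply again.
```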
@@ -93,14 +141,17 @@


 def register_kernel_mapping(
-    mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
+    mapping: Dict[
+        str,
+        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+    ],
 ):
     """
-    Allows one to register a mapping between a layer name the corresponding
-    kernel to use, depending on the device. This should be use in conjunction
+    Allows one to register a mapping between a layer name and the corresponding
+    kernel(s) to use, depending on the device. This should be used in conjunction
     with `kernelize`.

-    Exemple usage:
+    Example usage:

     ```python
     from kernels import LayerRepository, register_kernel_mapping
@@ -121,10 +172,16 @@
     for new_kernel, new_device_repos in mapping.items():
         device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
         for new_device, new_repo in new_device_repos.items():
-            if isinstance(new_device, str):
-                device_repo[Device(type=new_device)] = new_repo
+            device = (
+                Device(type=new_device) if isinstance(new_device, str) else new_device
+            )
+
+            if isinstance(new_repo, LayerRepository):
+                kernel_options = {Mode.DEFAULT: new_repo}
             else:
-                device_repo[new_device] = new_repo
+                kernel_options = new_repo
+
+            device_repo[device] = kernel_options


 def replace_kernel_forward_from_hub(
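Aside (not part of the diff): the loop above normalizes both accepted mapping shapes into one internal form, so a bare `LayerRepository` behaves like a `Mode.DEFAULT`-only dict (`repo` is a placeholder `LayerRepository`):

```python
# These two registrations are stored identically:
{"cuda": repo}                               # bare LayerRepository
{Device(type="cuda"): {Mode.DEFAULT: repo}}  # normalized internal form
```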
@@ -145,10 +202,24 @@
     cls.kernel_layer_name = layer_name


+def _select_repository(
+    repositories: Dict[Mode, LayerRepository],
+    *,
+    mode: Mode,
+) -> Optional[Tuple[LayerRepository, Mode]]:
+    if mode in repositories:
+        return (repositories[mode], mode)
+    elif Mode.DEFAULT in repositories:
+        return (repositories[Mode.DEFAULT], Mode.DEFAULT)
+    else:
+        return None
+
+
 def kernelize(
     model: "nn.Module",
     *,
+    mode: Mode,
     device: Optional[Union[str, "torch.device"]] = None,
-    needs_torch_compile: bool = False,
     use_fallback: bool = True,
 ):
     """
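Aside (not part of the diff): `_select_repository` tries an exact mode match first, then `Mode.DEFAULT`, otherwise `None`. A sketch of the internal helper's behavior, with placeholder `LayerRepository` values `inference_repo` and `default_repo`:

```python
repos = {Mode.INFERENCE: inference_repo, Mode.DEFAULT: default_repo}

_select_repository(repos, mode=Mode.INFERENCE)  # -> (inference_repo, Mode.INFERENCE)
_select_repository(repos, mode=Mode.TRAINING)   # -> (default_repo, Mode.DEFAULT)
_select_repository({Mode.INFERENCE: inference_repo}, mode=Mode.TRAINING)  # -> None
```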
@@ -158,10 +229,11 @@

     Args:
         model: The PyTorch model to kernelize
+        mode: the mode that the kernel is going to be used in (e.g.
+            `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
+            and `torch.compile`).
         device: The device type to load kernels for. The device type will be inferred
             from the parameters of the model when not provided.
-        needs_torch_compile: When set to `true`, only kernels that support
-            `torch.compile` will be loaded.
         use_fallback: Whether to use the original forward method of modules when no
             compatible kernel could be found. If set to `False`, an exception will
             be raised in such cases.
@@ -171,6 +243,15 @@
     """
     import torch

+    if mode == Mode.DEFAULT:
+        raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
+
+    # Type check ignored because this causes a false negative on Python < 3.11.
+    # Looks similar to: https://github.com/python/mypy/issues/9642
+    # Remove once we start doing typing checks on >= 3.11.
+    if Mode.INFERENCE not in mode and Mode.TRAINING not in mode:  # type: ignore[operator]
+        raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
+
     if device is None:
         device_type = _find_device(model)
     elif isinstance(device, str):
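Aside (not part of the diff): the two checks above mean every `kernelize` call must name a base mode. A sketch of what is accepted and rejected (`model` is a placeholder module):

```python
kernelize(model, mode=Mode.INFERENCE)                      # ok
kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)  # ok
kernelize(model, mode=Mode.DEFAULT)        # ValueError: only for registering mappings
kernelize(model, mode=Mode.TORCH_COMPILE)  # ValueError: needs INFERENCE or TRAINING
```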
@@ -203,10 +284,10 @@
             _replace_forward(module, module_class)
             continue

-        # Use device type string directly instead of Device object
-        repo = kernel.get(device_type)
+        # Get kernel options for the device
+        repos = kernel.get(device_type)

-        if repo is None:
+        if repos is None:
             if not use_fallback:
                 raise ValueError(
                     f"No layer mapping for `{layer_name}` with device type `{device_type}`"
@@ -214,32 +295,35 @@
             _replace_forward(module, module_class)
             continue

-        # Short-circuit if we already loaded the layer.
-        layer = _CACHED_LAYER.get(repo, None)
-        if layer is not None:
-            _conditionally_replace_forward(
-                module=module,
-                layer=layer,
-                needs_torch_compile=needs_torch_compile,
-                use_fallback=use_fallback,
-            )
-            continue
-
-        layer = _get_kernel_layer(
-            repo_id=repo.repo_id,
-            layer_name=repo.layer_name,
-            revision=repo.revision,
+        repo_with_mode = _select_repository(
+            repos,
+            mode=mode,
         )

-        # Validate the replacement layer against the class layer.
-        _validate_layer(check_cls=module_class, cls=layer)
+        if repo_with_mode is None:
+            if not use_fallback:
+                raise ValueError(
+                    f"No repository for `{layer_name}` for configuration mode={mode}"
+                )
+            _replace_forward(module, module_class)
+            continue

-        _CACHED_LAYER[repo] = layer
+        repo, repo_mode = repo_with_mode
+
+        layer = _get_layer_memoize(repo, module_class)
+
+        # Ideally we would do validation on the mapping where we check that
+        # e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
+        # the actual layer is compatible with that. Unfortunately, this would
+        # mean that we have to pre-download everything.
+        _validate_layer_has_mode(
+            layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
+        )

         _conditionally_replace_forward(
             module=module,
             layer=layer,
-            needs_torch_compile=needs_torch_compile,
+            mode=mode,
             use_fallback=use_fallback,
         )
@@ -331,45 +415,75 @@
     *,
     module: "nn.Module",
     layer: Type["nn.Module"],
-    needs_torch_compile: bool,
+    mode: Mode,
     use_fallback: bool,
 ):
     module_class = type(module)

-    # Switch to fallback when the layer does not support:
-    # compilation/compile when needed.
-    # backward when needed
-    needs_fallback = needs_torch_compile and not getattr(
+    # Switch to fallback if the mode is not supported by the layer.
+    # Note that this is useful even after _validate_layer_has_mode because
+    # layers registered with the DEFAULT mode never get rejected by
+    # _validate_layer_has_mode. For such layers, we want to fall back in
+    # case the layer does not support the given mode.
+    needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
         layer, "can_torch_compile", False
     )
+    needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)

     if needs_fallback:
         if use_fallback:
             _replace_forward(module, module_class)
         else:
-            raise ValueError(
-                f"Available kernel does not fulfill requirements: needs_torch_compile={needs_torch_compile}"
-            )
+            raise ValueError(f"Available kernel does not support mode: {mode}")
     else:
         _replace_forward(module, layer)


 def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
-    module.forward = MethodType(layer.forward, module)  # type: ignore[method-assign]
+    import torch.nn as nn
+
+    module_class = type(module)
+    layer_with_backward = (
+        layer if getattr(layer, "has_backward", True) else module_class
+    )
+
+    def train(self, mode: bool = True) -> nn.Module:
+        super(type(self), self).train(mode)
+        if mode:
+            self.forward = MethodType(layer_with_backward.forward, self)
+        else:
+            self.forward = MethodType(layer.forward, self)
+        return self
+
+    module.train = MethodType(train, module)  # type: ignore[method-assign]
+
+    # Trigger setting correct forward for the current state.
+    module.train(module.training)
+
+
+def _validate_layer_has_mode(
+    *,
+    layer_name: str,
+    module: Type["nn.Module"],
+    repo: LayerRepository,
+    repo_mode: Mode,
+):
+    """
+    Check that a repository supports the mode that it was registered for.
+    """
+    if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
+        raise ValueError(
+            f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
+            f"Was registered for `{layer_name}` with mode `{repo_mode}`"
+        )
+
+    if Mode.TORCH_COMPILE in repo_mode and not getattr(
+        module, "can_torch_compile", False
+    ):
+        raise ValueError(
+            f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
+            f"Was registered for `{layer_name}` with mode `{repo_mode}`"
+        )
+
+    return True
+
+
+def _get_layer_memoize(
+    repo: LayerRepository, module_class: Type["nn.Module"]
+) -> Type["nn.Module"]:
+    layer = _CACHED_LAYER.get(repo, None)
+    if layer is not None:
+        return layer
+
+    layer = _get_kernel_layer(
+        repo_id=repo.repo_id,
+        layer_name=repo.layer_name,
+        revision=repo.revision,
+    )
+    _validate_layer(check_cls=module_class, cls=layer)
+    _CACHED_LAYER[repo] = layer
+
+    return layer
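Aside (not part of the diff): the fallback and validation logic above keys off two optional class attributes on kernel layers, `has_backward` (default `True`) and `can_torch_compile` (default `False`). A hedged sketch of a hypothetical kernel layer declaring them:

```python
import torch
from torch import nn


class SiluAndMulKernel(nn.Module):  # hypothetical kernel layer
    has_backward = True       # usable under Mode.TRAINING
    can_torch_compile = True  # usable under Mode.TORCH_COMPILE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]
```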
@@ -110,6 +110,23 @@ def install_kernel(
         )
     )

+    try:
+        return _load_kernel_from_path(repo_path, package_name, variant_locks)
+    except FileNotFoundError:
+        # Redo with more specific error message.
+        raise FileNotFoundError(
+            f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
+        )
+
+
+def _load_kernel_from_path(
+    repo_path: Path,
+    package_name: str,
+    variant_locks: Optional[Dict[str, VariantLock]] = None,
+) -> Tuple[str, Path]:
+    variant = build_variant()
+    universal_variant = universal_build_variant()
+
     variant_path = repo_path / "build" / variant
     universal_variant_path = repo_path / "build" / universal_variant
@@ -128,7 +145,7 @@ def install_kernel(

     if not os.path.exists(module_init_path):
         raise FileNotFoundError(
-            f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
+            f"Kernel at path `{repo_path}` does not have build: {variant}"
         )

     return package_name, variant_path
@@ -166,10 +183,24 @@ def install_kernel_all_variants(


 def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
+    """
+    Download and import a kernel from the Hugging Face Hub.
+
+    The kernel is downloaded from the repository `repo_id` at
+    branch/commit/tag `revision`.
+    """
     package_name, package_path = install_kernel(repo_id, revision=revision)
     return import_from_path(package_name, package_path / package_name / "__init__.py")


+def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
+    """
+    Import a kernel from a local kernel repository path.
+    """
+    package_name, package_path = _load_kernel_from_path(repo_path, package_name)
+    return import_from_path(package_name, package_path / package_name / "__init__.py")
+
+
 def has_kernel(repo_id: str, revision: str = "main") -> bool:
     """
     Check whether a kernel build exists for the current environment
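Aside (not part of the diff): `install_kernel` and the new `get_local_kernel` compose as in the test fixture added below; `install_kernel` returns the build-variant path, so the repository root is two directory levels up:

```python
from kernels import get_local_kernel, install_kernel

package_name, variant_path = install_kernel("kernels-community/activation", "main")
# variant_path is <repo>/build/torch-<...>; the repo root is two levels up.
activation = get_local_kernel(variant_path.parent.parent, package_name)
```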
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from kernels import get_kernel, has_kernel
+from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel


 @pytest.fixture
@@ -9,6 +9,14 @@ def kernel():
     return get_kernel("kernels-community/activation")


+@pytest.fixture
+def local_kernel():
+    package_name, path = install_kernel("kernels-community/activation", "main")
+    # Path is the build variant path (build/torch-<...>), so the grandparent
+    # is the kernel repository path.
+    return get_local_kernel(path.parent.parent, package_name)
+
+
 @pytest.fixture
 def metal_kernel():
     return get_kernel("kernels-test/relu-metal")
@@ -42,6 +50,22 @@ def test_gelu_fast(kernel, device):
     assert torch.allclose(y, expected)


+@pytest.mark.linux_only
+def test_local_kernel(local_kernel, device):
+    x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
+    y = torch.empty_like(x)
+
+    local_kernel.gelu_fast(y, x)
+
+    expected = torch.tensor(
+        [[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
+        device=device,
+        dtype=torch.float16,
+    )
+
+    assert torch.allclose(y, expected)
+
+
 @pytest.mark.darwin_only
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_relu_metal(metal_kernel, dtype):
@@ -8,6 +8,7 @@ from torch.nn import functional as F
 from kernels import (
     Device,
     LayerRepository,
+    Mode,
     kernelize,
     register_kernel_mapping,
     use_kernel_forward_from_hub,
@@ -65,6 +66,18 @@ class SiluAndMulStringDevice(SiluAndMul):
     pass


+@use_kernel_forward_from_hub("Linear")
+class TorchLinearWithCounter(nn.Linear):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # Used to check that we called hub kernel.
+        self.n_calls = 0
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        self.n_calls += 1
+        return super().forward(input)
+
+
 def test_arg_kinds():
     @use_kernel_forward_from_hub("ArgKind")
     class ArgKind(nn.Module):
@@ -93,7 +106,7 @@ def test_hub_forward(cls, device):
     X = torch.randn((32, 64), device=device)
     Y = silu_and_mul(X)

-    silu_and_mul_with_kernel = kernelize(cls(), device=device)
+    silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
     Y_kernel = silu_and_mul_with_kernel(X)

     torch.testing.assert_close(Y_kernel, Y)
@@ -112,7 +125,7 @@ def test_layer_fallback_works():

     # Check that we don't raise an exception for a non-existing kernel.
     silu_and_mul = SiluAndMulWithKernelFallback()
-    kernelize(silu_and_mul, device="cuda")
+    kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


 @pytest.mark.linux_only
@@ -128,7 +141,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
     silu_and_mul_with_kernel.eval()

     ctx = (
-        pytest.raises(ValueError, match="does not fulfill requirements")
+        pytest.raises(ValueError, match="does not support mode")
         if cls is SiluAndMulNoCompileKernel
         else nullcontext()
     )
@@ -136,7 +149,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
         silu_and_mul_with_kernel = kernelize(
             silu_and_mul_with_kernel,
             device=device,
-            needs_torch_compile=True,
+            mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
             use_fallback=False,
         )
     silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
@@ -160,7 +173,7 @@ def test_torch_compile_layer_with_fallback(cls, device):
     silu_and_mul_with_kernel = kernelize(
         silu_and_mul_with_kernel,
         device=device,
-        needs_torch_compile=True,
+        mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
     )
     silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
@@ -212,7 +225,9 @@
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
+            Mode.DEFAULT
+        ].repo_id
         == "kernels-community/non-existing"
     )
@@ -223,7 +238,9 @@
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
+            Mode.DEFAULT
+        ].repo_id
         == "kernels-community/activation"
     )
@@ -232,7 +249,9 @@
         "SiluAndMul",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
+            Mode.DEFAULT
+        ].repo_id
         == "kernels-community/non-existing"
     )
@@ -243,7 +262,9 @@
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
+            Mode.DEFAULT
+        ].repo_id
         == "kernels-community/activation"
     )
@@ -282,20 +303,149 @@ def test_validate_kernel_layer():
         _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


+def test_invalid_mode_for_mapping_rejected():
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearNoBackward",
+                    )
+                }
+            }
+        }
+    ):
+        with pytest.raises(ValueError, match="does not support backward"):
+            kernelize(linear, mode=Mode.TRAINING)
+
+
+def test_kernel_modes():
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+
+    # Case 1: layer without further specification, becomes the
+    # base layer.
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        assert linear.n_calls == 0
+
+    # Case 2: register a kernel just for training. If no base kernel
+    # layer is registered, we fall back to the original layer.
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 1
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # Training has a kernel, so no fallback.
+        assert linear.n_calls == 1
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # No kernel for training + torch.compile, so fallback.
+        assert linear.n_calls == 2
+
+    # Case 3: register a kernel just for training and one for fallback.
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.DEFAULT: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # Uses the base kernel.
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # Uses the training kernel.
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # Uses the base kernel.
+        assert linear.n_calls == 2
+
+    # Case 4: register a kernel with two preferences.
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.TRAINING
+                    | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # No inference kernel, so fallback.
+        assert linear.n_calls == 3
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # No training kernel, so fallback.
+        assert linear.n_calls == 4
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # We do have a training + torch.compile kernel.
+        assert linear.n_calls == 4
+
+
 @pytest.mark.linux_only
 def test_fallback_used_when_training():
-    @use_kernel_forward_from_hub("Linear")
-    class TorchLinear(nn.Linear):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            # Used to check that we called hub kernel.
-            self.n_calls = 0
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
-            self.n_calls += 1
-            return super().forward(input)
-
-    linear = TorchLinear(32, 32).to("cuda")
+    linear = TorchLinearWithCounter(32, 32).to("cuda")

     # Case 1: kernel with explicit backward support should always
     # use the kernel.
@@ -310,7 +460,7 @@ def test_fallback_used_when_training():
         }
     ):
         linear.train()
-        kernelize(linear)
+        kernelize(linear, mode=Mode.INFERENCE)
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         assert linear.n_calls == 0
@@ -332,7 +482,7 @@ def test_fallback_used_when_training():
         }
     ):
         linear.train()
-        kernelize(linear)
+        kernelize(linear, mode=Mode.INFERENCE)
         X = torch.randn(10, 32, device="cuda")
         linear(X)
         assert linear.n_calls == 0
@@ -341,57 +491,18 @@ def test_fallback_used_when_training():
         linear(X)
         assert linear.n_calls == 0

-    # Case 3: kernel out backward support should use the kernel in
-    # eval mode and the fallback in training. Test train ->
-    # eval -> train.
-    with use_kernel_mapping(
-        {
-            "Linear": {
-                Device(type="cuda"): LayerRepository(
-                    repo_id="kernels-test/backward-marker-test",
-                    layer_name="LinearNoBackward",
-                )
-            }
-        }
-    ):
-        linear.train()
-        kernelize(linear)
-        X = torch.randn(10, 32, device="cuda")
-        linear(X)
-        assert linear.n_calls == 1
-
-        # When switching the kernel to eval, forward gets replaced by
-        # the kernel.
-        linear.eval()
-        linear(X)
-        assert linear.n_calls == 1
-
-        ## Let's do it in the other direction to make sure it works as well.
-        linear.train()
-        linear(X)
-        assert linear.n_calls == 2
-
-    # Case 4: same as case 3, but test eval -> train -> eval.
-    with use_kernel_mapping(
-        {
-            "Linear": {
-                Device(type="cuda"): LayerRepository(
-                    repo_id="kernels-test/backward-marker-test",
-                    layer_name="LinearNoBackward",
-                )
-            }
-        }
-    ):
-        linear.eval()
-        kernelize(linear)
-        X = torch.randn(10, 32, device="cuda")
-        linear(X)
-        assert linear.n_calls == 2
-
-        linear.train()
-        linear(X)
-        assert linear.n_calls == 3
-
-        linear.eval()
-        linear(X)
-        assert linear.n_calls == 3
+
+def test_invalid_mode_rejected():
+    with pytest.raises(ValueError, match="mutually exclusive"):
+        _ = Mode.INFERENCE | Mode.TRAINING
+
+    with pytest.raises(ValueError, match="cannot be combined with other modes"):
+        _ = Mode.DEFAULT | Mode.TORCH_COMPILE
+
+    with pytest.raises(
+        ValueError, match="can only be used to register kernel mappings"
+    ):
+        kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
+
+    with pytest.raises(ValueError, match="mode must contain"):
+        kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)