Compare commits


1 Commit

Author SHA1 Message Date
a30e82182e Set version to 0.6.2 2025-06-25 08:08:21 +00:00
9 changed files with 159 additions and 1217 deletions

View File

@ -49,30 +49,16 @@ A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
model = kernelize(model)
```
The `kernelize` function modifies the model in-place; the model
**Note:** the `kernelize` function modifies the model in-place; the model
itself is returned as a convenience.
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
When the `mode` argument is not specified, the
`Mode.TRAINING | Mode.TORCH_COMPILE` mode is used as the default. This mode
aligns most closely with pure PyTorch layers (which generally support backward
passes and `torch.compile`). However, this mode can also lead to fewer
kernels being used, since not all kernels support training or `torch.compile`.
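Since the default already covers training and `torch.compile`, the following
calls are equivalent (a small sketch restating the default described above):
```python
model = MyModel(...)

# These two calls kernelize the model in the same way.
model = kernelize(model)
model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
```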
### Kernel device
Kernels can be registered per device type. For instance, separate `cuda` and
@ -83,37 +69,36 @@ inferred (e.g. because the model has no parameters):
```python
model = MyModel(...)
model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
model = kernelize(model, device="cuda")
```
### `torch.compile`
Not all Hub kernels support `torch.compile`. If you want to compile a model
after kernelizing it, you need to add this to the mode. You can use the
set union (`|`) operator to add `TORCH_COMPILE` to the mode:
after kernelizing it, pass the `needs_torch_compile` argument to ensure that
only kernels that support `torch.compile` will be loaded:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
model = kernelize(model, needs_torch_compile=True)
```
### Fallback `forward`
### Fallback forward
If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
kernel does not support backward passes or `torch.compile` respectively,
`kernelize` will fall back to the original, non-kernelized layer. You
With `needs_torch_compile`, `kernelize` will fall back to the layer's original
`forward` if the registered kernel does not support `torch.compile`. You
can let `kernelize` raise an exception instead by using `use_fallback=False`:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=False)
model = kernelize(model, needs_torch_compile=True, use_fallback=False)
```
This can be useful if you want to guarantee that Hub kernels are used.
## Registering a hub kernel for a layer
`kernelize` relies on kernel mappings to find Hub kernels for layers.
Kernel mappings map a kernel name such as `SiluAndMul` to a kernel on
the Hub. For example:
@ -123,6 +108,7 @@ kernel_layer_mapping = {
"cuda": LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
revision="layers",
)
}
}
@ -146,133 +132,3 @@ with use_kernel_mapping(kernel_layer_mapping):
This ensures that the mapping is no longer active outside the
`with` scope.
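For reference, a minimal sketch of the scoped-mapping pattern described above,
assuming `kernel_layer_mapping` is defined as in the earlier example:
```python
with use_kernel_mapping(kernel_layer_mapping):
    # The mapping is only active within this scope.
    model = kernelize(model)
```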
### Registering kernels for specific modes
You might want to register two different kernels for a particular layer,
where one kernel is optimized for a specific mode. You can do so by
registering layer repositories for specific modes. For example:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```
Kernels are matched exactly on the mode. For instance, in the example
above, no kernel layer is used when the `mode` passed to `kernelize` is
`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
register a kernel to be used when the mode does not match any of the
modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
so. For example:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            Mode.DEFAULT: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/activation-inference-optimized",
                layer_name="SiluAndMul",
            ),
            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation-training-optimized",
                layer_name="SiluAndMul",
            ),
        }
    }
}
```
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
### Mode fallback behavior
As described above, if there is no exact match for the mode given to
`kernelize`, it will try to use the kernel registered for `Mode.DEFAULT`.
If the `Mode.DEFAULT` kernel does not support the `kernelize` mode, the
original layer's `forward` method will be used instead.
As an example, suppose that two kernels were registered for a layer:
1. Kernel `A` is registered for `Mode.DEFAULT`. This kernel supports training
(backward), but not `torch.compile`.
2. Kernel `B` is registered for `Mode.INFERENCE | Mode.TORCH_COMPILE` and supports
`torch.compile`.
`kernelize` modes will then behave as follows:
- `Mode.INFERENCE | Mode.TORCH_COMPILE` uses kernel `B`: exact match.
- `Mode.INFERENCE` uses kernel `A`: no exact match, so fall back to
`Mode.DEFAULT`.
- `Mode.TRAINING` uses kernel `A`: no exact match, so fall back to
`Mode.DEFAULT`, which supports training.
- `Mode.TRAINING | Mode.TORCH_COMPILE` uses the original layer's
`forward`: no exact match, and falling back to `Mode.DEFAULT` is not possible
because kernel `A` does not support `torch.compile`.
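A mapping matching this example might look as follows (the repository names
are hypothetical placeholders):
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            # Kernel A: supports backward, but not torch.compile.
            Mode.DEFAULT: LayerRepository(
                repo_id="example-org/kernel-a",  # hypothetical repository
                layer_name="SiluAndMul",
            ),
            # Kernel B: supports torch.compile.
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="example-org/kernel-b",  # hypothetical repository
                layer_name="SiluAndMul",
            ),
        }
    }
}
```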
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
you need to register the layer for a `Device` with type `cuda` and
set the supported range of CUDA capabilities using `CUDAProperties`:
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=75, max_capability=89
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        ),
        Device(
            type="cuda",
            properties=CUDAProperties(
                min_capability=90, max_capability=sys.maxsize
            ),
        ): LayerRepository(
            repo_id="kernels-community/activation-hopper",
            layer_name="SiluAndMul",
        ),
    }
}
```
Capabilities behave as follows:
- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
with the smaller capability interval will be used. E.g. given:
- `KernelA` with `min_capability=80` and `max_capability=89`;
- `KernelB` with `min_capability=75` and `max_capability=89`;
- `kernelize` runs on a system with capability 8.6.
Then `KernelA` will be used because the interval 80..89 is smaller
than 75..89. The motivation is that kernels with smaller ranges
tend to be more optimized for a specific set of GPUs. **This behavior
might still change in the future.**
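Expressed as a mapping, the `KernelA`/`KernelB` example above might look like
this (repository names are hypothetical placeholders):
```python
kernel_layer_mapping = {
    "SiluAndMul": {
        # KernelA: interval 80..89, preferred on capability 8.6.
        Device(
            type="cuda",
            properties=CUDAProperties(min_capability=80, max_capability=89),
        ): LayerRepository(
            repo_id="example-org/kernel-a",  # hypothetical repository
            layer_name="SiluAndMul",
        ),
        # KernelB: interval 75..89.
        Device(
            type="cuda",
            properties=CUDAProperties(min_capability=75, max_capability=89),
        ): LayerRepository(
            repo_id="example-org/kernel-b",  # hypothetical repository
            layer_name="SiluAndMul",
        ),
    }
}
```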

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.7.0.dev0"
version = "0.6.2"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
@ -24,7 +24,7 @@ build-backend = "setuptools.build_meta"
[dependency-groups]
dev = [
"mypy >= 1.15.0",
"mypy == 1.14.1",
"pytest >=8",
# Whatever version is compatible with pytest.
"pytest-benchmark",

View File

@ -1,8 +1,6 @@
from kernels.layer import (
CUDAProperties,
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
replace_kernel_forward_from_hub,
@ -11,7 +9,6 @@ from kernels.layer import (
)
from kernels.utils import (
get_kernel,
get_local_kernel,
get_locked_kernel,
has_kernel,
install_kernel,
@ -19,19 +16,16 @@ from kernels.utils import (
)
__all__ = [
"CUDAProperties",
"Device",
"LayerRepository",
"Mode",
"get_kernel",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
"install_kernel",
"kernelize",
"load_kernel",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"install_kernel",
"use_kernel_forward_from_hub",
"use_kernel_mapping",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
"LayerRepository",
"Device",
"kernelize",
]

View File

@ -1,200 +0,0 @@
# AVL-balanced interval trees. We could use the intervaltree
# package, but it seems unmaintained and does not have type
# annotations.
from typing import Generic, List, Optional, Tuple, TypeVar
T = TypeVar("T")
class _Node(Generic[T]):
"""A node in the interval tree."""
def __init__(self, start: int, end: int, data: T):
self.start: int = start
self.end: int = end
self.data: T = data
self.max_end: int = end
self.left: Optional["_Node[T]"] = None
self.right: Optional["_Node[T]"] = None
self.height: int = 1
def __repr__(self) -> str:
return f"Node({self.start}, {self.end})"
class IntervalTree(Generic[T]):
"""A data structure to hold and query (unique) intervals."""
root: Optional[_Node[T]]
def __init__(self):
self.root = None
def insert(self, start: int, end: int, data: T) -> None:
"""
Inserts a new interval into the tree.
Args:
start: The starting point of the interval.
end: The ending point of the interval.
data: The data associated with this interval.
"""
self.root = self._insert(self.root, start, end, data)
def _get_height(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return node.height
def _get_balance(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return self._get_height(node.left) - self._get_height(node.right)
def _update_node_attributes(self, node: _Node[T]) -> None:
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
node.max_end = node.end
if node.left:
node.max_end = max(node.max_end, node.left.max_end)
if node.right:
node.max_end = max(node.max_end, node.right.max_end)
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
"""Performs a right rotation."""
x = y.left
assert x is not None
T2 = x.right
x.right = y
y.left = T2
self._update_node_attributes(y)
self._update_node_attributes(x)
return x
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
"""Performs a left rotation."""
y = x.right
assert y is not None
T2 = y.left
y.left = x
x.right = T2
self._update_node_attributes(x)
self._update_node_attributes(y)
return y
def _insert(
self, node: Optional[_Node[T]], start: int, end: int, data: T
) -> _Node[T]:
"""Recursive helper to insert a new node and balance the tree."""
if not node:
return _Node(start, end, data)
# Replace the data if the interval already exists.
if start == node.start and end == node.end:
node.data = data
return node
if start < node.start:
node.left = self._insert(node.left, start, end, data)
else:
node.right = self._insert(node.right, start, end, data)
self._update_node_attributes(node)
balance = self._get_balance(node)
# Left Left Case
if balance > 1 and node.left and start < node.left.start:
return self._right_rotate(node)
# Right Right Case
if balance < -1 and node.right and start >= node.right.start:
return self._left_rotate(node)
# Left Right Case
if balance > 1 and node.left and start >= node.left.start:
node.left = self._left_rotate(node.left)
return self._right_rotate(node)
# Right Left Case
if balance < -1 and node.right and start < node.right.start:
node.right = self._right_rotate(node.right)
return self._left_rotate(node)
return node
def search(self, point: int) -> List[T]:
"""
Searches for all intervals that contain the given point.
Args:
point: The point to search for.
Returns:
A list of data items from all matching intervals.
"""
results: List[T] = []
self._search(self.root, point, results)
return results
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
"""Recursive helper to find all overlapping intervals."""
if node is None or point > node.max_end:
return
if node.left:
self._search(node.left, point, results)
if node.start <= point <= node.end:
results.append(node.data)
if point >= node.start and node.right:
self._search(node.right, point, results)
def find_smallest_interval(self, point: int) -> Optional[T]:
"""
Finds the item with the most specific (smallest) range for a given point.
Args:
point: The capability to look up.
Returns:
The data of the best-matching item, or None if no match is found.
"""
matches: List[Tuple[int, int, T]] = []
self._find_with_intervals(self.root, point, matches)
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# is just to ensure that we can compare against a trivial
# implementation in tests.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def _find_with_intervals(
self,
node: Optional[_Node[T]],
point: int,
results: List[Tuple[int, int, T]],
) -> None:
"""A modified search that collects interval ranges along with data."""
if node is None or point > node.max_end:
return
if node.left:
self._find_with_intervals(node.left, point, results)
if node.start <= point <= node.end:
results.append((node.start, node.end, node.data))
if point >= node.start and node.right:
self._find_with_intervals(node.right, point, results)
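# A brief usage sketch of the class above; intervals are inclusive and
# the smallest matching interval wins:
#
#     tree = IntervalTree[str]()
#     tree.insert(80, 89, "kernel-a")
#     tree.insert(75, 89, "kernel-b")
#     assert tree.find_smallest_interval(86) == "kernel-a"  # smaller interval
#     assert tree.find_smallest_interval(70) is None        # point in no interval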

View File

@ -1,26 +1,12 @@
from __future__ import annotations
import inspect
import os
import sys
import warnings
from abc import ABC, abstractmethod
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Flag, auto
from functools import lru_cache
from types import MethodType
from typing import (
TYPE_CHECKING,
Dict,
Optional,
Tuple,
Type,
Union,
)
from typing import TYPE_CHECKING, Dict, Optional, Type, Union
from ._interval_tree import IntervalTree
from .utils import get_kernel
if TYPE_CHECKING:
@ -31,84 +17,17 @@ if TYPE_CHECKING:
_DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0")))
class Mode(Flag):
"""
Kernelize mode
The `Mode` flag is used by `kernelize` to select kernels for the given
mode. Mappings can be registered for specific modes.
* `INFERENCE`: The kernel is used for inference.
* `TRAINING`: The kernel is used for training.
* `TORCH_COMPILE`: The kernel is used with `torch.compile`.
* `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
matches.
Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
should be used for layers that are used for inference *with* `torch.compile`.
"""
_NONE = 0
DEFAULT = auto()
TRAINING = auto()
INFERENCE = auto()
TORCH_COMPILE = auto()
def __or__(self, other: Mode) -> Mode:
union = super().__or__(other)
if Mode.INFERENCE in union and Mode.TRAINING in union:
raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")
if Mode.DEFAULT in union and union != Mode.DEFAULT:
raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
return union
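# Illustrating the constraints enforced above:
#
#     Mode.INFERENCE | Mode.TORCH_COMPILE  # valid combination
#     Mode.INFERENCE | Mode.TRAINING       # raises ValueError
#     Mode.DEFAULT | Mode.TORCH_COMPILE    # raises ValueError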
@dataclass(frozen=True)
class Device:
type: str
properties: Optional[CUDAProperties] = None
def __post_init__(self):
if self.properties is not None and isinstance(self.properties, CUDAProperties):
if self.type != "cuda":
raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
def create_repo(self) -> _DeviceRepos:
"""Create an appropriate repository set for this device type."""
if self.type == "cuda":
return _CUDARepos()
elif self.type == "mps":
return _MPSRepos()
else:
raise ValueError(f"Unknown device type: {self.type}")
# In the future we might add compute capabilities, etc.
def __eq__(self, other):
if not isinstance(other, Device):
return NotImplemented
return self.type == other.type and self.properties == other.properties
return isinstance(other, Device) and self.type == other.type
def __hash__(self):
return hash((self.type, self.properties))
@dataclass(frozen=True)
class CUDAProperties:
min_capability: int
max_capability: int
def __eq__(self, other):
if not isinstance(other, CUDAProperties):
return NotImplemented
return (
self.min_capability == other.min_capability
and self.max_capability == other.max_capability
)
def __hash__(self):
return hash((self.min_capability, self.max_capability))
return hash(self.type)
@dataclass
@ -140,86 +59,13 @@ class LayerRepository:
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
class _DeviceRepos(ABC):
"""
Device-specific kernel layer repositories.
"""
@property
@abstractmethod
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]: ...
@abstractmethod
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
"""
Insert a repository for a specific device and mode.
"""
...
class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepository]
def __init__(self):
super().__init__()
self._repos = {}
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]:
return self._repos
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
if device.type != "mps":
raise ValueError(f"Device type must be 'mps', got {device.type}")
self._repos = repos
class _CUDARepos(_DeviceRepos):
_repos: IntervalTree[Dict[Mode, LayerRepository]]
def __init__(self):
super().__init__()
self.repos_by_capability = IntervalTree()
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]:
capability = _find_capability()
return self.repos_by_capability.find_smallest_interval(capability)
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
assert device.properties is None or isinstance(
device.properties, CUDAProperties
)
min_capability = (
0 if device.properties is None else device.properties.min_capability
)
max_capability = (
sys.maxsize
if device.properties is None
else device.properties.max_capability
)
self.repos_by_capability.insert(min_capability, max_capability, repos)
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, LayerRepository]]] = ContextVar(
"_KERNEL_MAPPING", default={}
)
def use_kernel_mapping(
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
*,
inherit_mapping: bool = True,
):
@ -247,17 +93,14 @@ def use_kernel_mapping(
def register_kernel_mapping(
mapping: Dict[
str,
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
],
mapping: Dict[str, Dict[Union[Device, str], LayerRepository]],
):
"""
Allows one to register a mapping between a layer name and the corresponding
kernel(s) to use, depending on the device. This should be used in conjunction
Allows one to register a mapping between a layer name and the corresponding
kernel to use, depending on the device. This should be used in conjunction
with `kernelize`.
Example usage:
```python
from kernels import LayerRepository, register_kernel_mapping
@ -278,17 +121,10 @@ def register_kernel_mapping(
for new_kernel, new_device_repos in mapping.items():
device_repo = _KERNEL_MAPPING.get().setdefault(new_kernel, {})
for new_device, new_repo in new_device_repos.items():
device = (
Device(type=new_device) if isinstance(new_device, str) else new_device
)
if isinstance(new_repo, LayerRepository):
kernel_options = {Mode.DEFAULT: new_repo}
if isinstance(new_device, str):
device_repo[Device(type=new_device)] = new_repo
else:
kernel_options = new_repo
feature_repos = device_repo.setdefault(device.type, device.create_repo())
feature_repos.insert(device, kernel_options)
device_repo[new_device] = new_repo
def replace_kernel_forward_from_hub(
@ -309,24 +145,10 @@ def replace_kernel_forward_from_hub(
cls.kernel_layer_name = layer_name
def _select_repository(
repositories: Dict[Mode, LayerRepository],
*,
mode: Mode,
) -> Optional[Tuple[LayerRepository, Mode]]:
if mode in repositories:
return (repositories[mode], mode)
elif Mode.DEFAULT in repositories:
return (repositories[Mode.DEFAULT], Mode.DEFAULT)
else:
return None
def kernelize(
model: "nn.Module",
*,
mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
device: Optional[Union[str, "torch.device"]] = None,
needs_torch_compile: bool = False,
use_fallback: bool = True,
):
"""
@ -336,11 +158,10 @@ def kernelize(
Args:
model: The PyTorch model to kernelize
mode: the mode that the kernel is going to be used in (e.g.
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
and `torch.compile`).
device: The device type to load kernels for. The device type will be inferred
from the parameters of the model when not provided.
needs_torch_compile: When set to `True`, only kernels that support
`torch.compile` will be loaded.
use_fallback: Whether to use the original forward method of modules when no
compatible kernel could be found. If set to `False`, an exception will
be raised in such cases.
@ -350,22 +171,12 @@ def kernelize(
"""
import torch
if mode == Mode.DEFAULT:
raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
# Type check ignored because this causes a false negative on Python < 3.11.
# Looks similar to: https://github.com/python/mypy/issues/9642
# Remove once we start doing typing checks on >= 3.11.
if Mode.INFERENCE not in mode and Mode.TRAINING not in mode: # type: ignore[operator]
raise ValueError("kernelize mode must contain Mode.INFERENCE or Mode.TRAINING.")
if device is None:
device_type = _find_device(model)
elif isinstance(device, str):
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
@ -392,10 +203,10 @@ def kernelize(
_replace_forward(module, module_class)
continue
# Get kernel options for the device
property_repos = kernel.get(device_type.type)
# Look up the kernel repository for the device.
repo = kernel.get(device_type)
if property_repos is None:
if repo is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
@ -403,45 +214,32 @@ def kernelize(
_replace_forward(module, module_class)
continue
repos = property_repos.repos
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
)
_replace_forward(module, module_class)
# Short-circuit if we already loaded the layer.
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
_conditionally_replace_forward(
module=module,
layer=layer,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
continue
repo_with_mode = _select_repository(
repos,
mode=mode,
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
)
if repo_with_mode is None:
if not use_fallback:
raise ValueError(
f"No repository for `{layer_name}` for configuration mode={mode}"
)
_replace_forward(module, module_class)
continue
# Validate the replacement layer against the class layer.
_validate_layer(check_cls=module_class, cls=layer)
repo, repo_mode = repo_with_mode
layer = _get_layer_memoize(repo, module_class)
# Ideally we would do validation on the mapping where we check that
# e.g. if a repo class is registered for TRAINING | TORCH_COMPILE,
# the actual layer is compatible with that. Unfortunately, this would
# mean that we have to pre-download everything.
_validate_layer_has_mode(
layer_name=layer_name, module=layer, repo=repo, repo_mode=repo_mode
)
_CACHED_LAYER[repo] = layer
_conditionally_replace_forward(
module=module,
layer=layer,
mode=mode,
needs_torch_compile=needs_torch_compile,
use_fallback=use_fallback,
)
@ -529,87 +327,49 @@ def _find_device(model: "nn.Module") -> Device:
return Device(type=param.device.type)
@lru_cache
def _find_capability() -> int:
import torch
major, minor = torch.cuda.get_device_capability(device=None)
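# E.g. compute capability 9.0 (Hopper) yields 90, matching the scale of
# `min_capability`/`max_capability` in kernel mappings.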
return major * 10 + minor
def _conditionally_replace_forward(
*,
module: "nn.Module",
layer: Type["nn.Module"],
mode: Mode,
needs_torch_compile: bool,
use_fallback: bool,
):
module_class = type(module)
# Switch to fallback if the mode is not supported by the layer.
# Note that this is useful even after _validate_layer_has_mode because
# layers registered with the DEFAULT mode never get rejected by
# _validate_layer_has_mode. For such layers, we want to fall back in
# case the layer does not support the given mode.
needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
# Switch to fallback when torch.compile support is required but the
# layer does not provide it. (Backward support is handled separately
# in `_replace_forward`.)
needs_fallback = needs_torch_compile and not getattr(
layer, "can_torch_compile", False
)
needs_fallback |= Mode.TRAINING in mode and not getattr(layer, "has_backward", True)
if needs_fallback:
if use_fallback:
_replace_forward(module, module_class)
else:
raise ValueError(f"Available kernel does not support mode: {mode}")
raise ValueError(
f"Available kernel does not fulfill requirements: needs_torch_compile={needs_torch_compile}"
)
else:
_replace_forward(module, layer)
def _replace_forward(module: "nn.Module", layer: Type["nn.Module"]):
module.forward = MethodType(layer.forward, module) # type: ignore[method-assign]
import torch.nn as nn
def _validate_layer_has_mode(
*,
layer_name: str,
module: Type["nn.Module"],
repo: LayerRepository,
repo_mode: Mode,
):
"""
Check that a repository supports the mode that it was registered for.
"""
if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support backward.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
if Mode.TORCH_COMPILE in repo_mode and not getattr(
module, "can_torch_compile", False
):
raise ValueError(
f"Layer `{repo.layer_name}` ({repo.repo_id}, revision: {repo.revision}) does not support torch.compile.\n"
f"Was registered for `{layer_name}` with mode `{repo_mode}`"
)
return True
def _get_layer_memoize(
repo: LayerRepository, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
layer = _CACHED_LAYER.get(repo, None)
if layer is not None:
return layer
layer = _get_kernel_layer(
repo_id=repo.repo_id,
layer_name=repo.layer_name,
revision=repo.revision,
module_class = type(module)
layer_with_backward = (
layer if getattr(layer, "has_backward", True) else module_class
)
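# If the kernel layer lacks backward support, `train(True)` below swaps
# in the module's original forward; `eval()` swaps the kernel back in.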
_validate_layer(check_cls=module_class, cls=layer)
_CACHED_LAYER[repo] = layer
return layer
def train(self, mode: bool = True) -> nn.Module:
super(type(self), self).train(mode)
if mode:
self.forward = MethodType(layer_with_backward.forward, self)
else:
self.forward = MethodType(layer.forward, self)
return self
module.train = MethodType(train, module) # type: ignore[method-assign]
# Trigger setting correct forward for the current state.
module.train(module.training)

View File

@ -110,23 +110,6 @@ def install_kernel(
)
)
try:
return _load_kernel_from_path(repo_path, package_name, variant_locks)
except FileNotFoundError:
# Re-raise with a more specific error message.
raise FileNotFoundError(
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
def _load_kernel_from_path(
repo_path: Path,
package_name: str,
variant_locks: Optional[Dict[str, VariantLock]] = None,
) -> Tuple[str, Path]:
variant = build_variant()
universal_variant = universal_build_variant()
variant_path = repo_path / "build" / variant
universal_variant_path = repo_path / "build" / universal_variant
@ -145,7 +128,7 @@ def _load_kernel_from_path(
if not os.path.exists(module_init_path):
raise FileNotFoundError(
f"Kernel at path `{repo_path}` does not have build: {variant}"
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
)
return package_name, variant_path
@ -183,24 +166,10 @@ def install_kernel_all_variants(
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
"""
Download and import a kernel from the Hugging Face Hub.
The kernel is downloaded from the repository `repo_id` at
branch/commit/tag `revision`.
"""
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
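# For instance (using a kernel repository that appears in the tests):
#
#     activation = get_kernel("kernels-community/activation")
#     activation.gelu_fast(y, x)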
def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
"""
Import a kernel from a local kernel repository path.
"""
package_name, package_path = _load_kernel_from_path(repo_path, package_name)
return import_from_path(package_name, package_path / package_name / "__init__.py")
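# For instance, mirroring the local-kernel test fixture:
#
#     package_name, path = install_kernel("kernels-community/activation", "main")
#     # `path` is the build variant path, so the repo root is two levels up.
#     kernel = get_local_kernel(path.parent.parent, package_name)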
def has_kernel(repo_id: str, revision: str = "main") -> bool:
"""
Check whether a kernel build exists for the current environment

View File

@ -1,7 +1,7 @@
import pytest
import torch
from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
from kernels import get_kernel, has_kernel
@pytest.fixture
@ -9,14 +9,6 @@ def kernel():
return get_kernel("kernels-community/activation")
@pytest.fixture
def local_kernel():
package_name, path = install_kernel("kernels-community/activation", "main")
# Path is the build variant path (build/torch-<...>), so the grandparent
# is the kernel repository path.
return get_local_kernel(path.parent.parent, package_name)
@pytest.fixture
def metal_kernel():
return get_kernel("kernels-test/relu-metal")
@ -50,22 +42,6 @@ def test_gelu_fast(kernel, device):
assert torch.allclose(y, expected)
@pytest.mark.linux_only
def test_local_kernel(local_kernel, device):
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
y = torch.empty_like(x)
local_kernel.gelu_fast(y, x)
expected = torch.tensor(
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
device=device,
dtype=torch.float16,
)
assert torch.allclose(y, expected)
@pytest.mark.darwin_only
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_relu_metal(metal_kernel, dtype):

View File

@ -1,230 +0,0 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar
import pytest
from kernels._interval_tree import IntervalTree, _Node
T = TypeVar("T")
class SimpleIntervalStore(Generic[T]):
"""A simple O(n) implementation that stores intervals in a list."""
def __init__(self):
self.intervals: List[Tuple[int, int, T]] = []
def insert(self, start: int, end: int, data: T) -> None:
"""Insert an interval into the store."""
# Replace data if the interval already exists.
for i, (existing_start, existing_end, existing_data) in enumerate(
self.intervals
):
if existing_start == start and existing_end == end:
self.intervals[i] = (start, end, data)
return
self.intervals.append((start, end, data))
def find_smallest_interval(self, point: int) -> Optional[T]:
"""Find the best match using linear search."""
matches = []
for start, end, data in self.intervals:
if start <= point <= end:
matches.append((start, end, data))
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# mirrors the ordering in the interval tree.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def is_balanced(tree: IntervalTree[T]) -> bool:
"""Check if the AVL tree is properly balanced."""
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
if node is None:
return True, 0
# Left and right subtrees should be balanced.
left_balanced, left_height = check_balance(node.left)
if not left_balanced:
return False, -1
right_balanced, right_height = check_balance(node.right)
if not right_balanced:
return False, -1
# The difference in height should not exceed 1.
if abs(left_height - right_height) > 1:
return False, -1
# Check if the height is correct.
expected_height = 1 + max(left_height, right_height)
if node.height != expected_height:
return False, -1
return True, expected_height
balanced, _ = check_balance(tree.root)
return balanced
@pytest.fixture
def populated_tree() -> IntervalTree[str]:
"""Provides a pre-populated IntervalTree for testing."""
tree = IntervalTree[str]()
kernels = [
(80, 89, "Kernel_A_General_80_89"),
(86, 89, "Kernel_B_Ampere_86_89"),
(80, 86, "Kernel_C_Older_Ampere_80_86"),
(70, 75, "Kernel_D_Volta_70_75"),
(86, 87, "Kernel_E_Specific_86_87"),
]
for start, end, name in kernels:
tree.insert(start, end, name)
return tree
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
# Check that the smallest interval is selected when there are
# multiple matching intervals.
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
def test_find_single_match(populated_tree):
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
def test_no_match_outside_all_ranges(populated_tree):
# Check that no interval is found when the value is out of range
# (too small/too large).
assert populated_tree.find_smallest_interval(65) is None
assert populated_tree.find_smallest_interval(95) is None
def test_no_match_in_gap_between_ranges(populated_tree):
# Check that no interval is found when the value is between two
# intervals.
assert populated_tree.find_smallest_interval(78) is None
def test_boundary_conditions_start_and_end(populated_tree):
# Test exact upper/lower bounds of intervals.
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
def test_empty_tree():
# Searching in an empty tree should return None.
empty_tree = IntervalTree[str]()
assert empty_tree.find_smallest_interval(100) is None
def test_multiple_equally_specific_matches():
# Check that we pick the match in a stable way when there are
# multiple matching intervals with the same size.
tree = IntervalTree[str]()
str1 = "First_Narrow_Kernel"
str2 = "Second_Narrow_Kernel"
tree.insert(10, 20, "Wide_Kernel")
tree.insert(12, 17, str1)
tree.insert(14, 19, str2)
if id(str1) < id(str2):
assert tree.find_smallest_interval(15) == str1
else:
assert tree.find_smallest_interval(15) == str2
def test_property_based_interval_tree():
# Quick-check property-based testing:
#
# - Verify that the tree is balanced after each insertion.
# - Verify the query against a simple list-based implementation.
random.seed(42) # For reproducible tests
test_points = list(range(0, 101))
for _ in range(5):
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
intervals = []
for i in range(100):
start = random.randint(0, 90)
end = random.randint(start, 100)
data = f"interval_{i}_s{start}_e{end}"
intervals.append((start, end, data))
for i, (start, end, data) in enumerate(intervals):
tree.insert(start, end, data)
simple.insert(start, end, data)
# Check that tree is still balanced
assert is_balanced(
tree
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
for point in test_points:
tree_result = tree.find_smallest_interval(point)
simple_result = simple.find_smallest_interval(point)
assert tree_result == simple_result, (
f"Mismatch for point {point} after inserting {i+1} intervals. "
f"Tree: {tree_result}, Simple: {simple_result}. "
f"Last inserted: ({start}, {end})"
)
def test_property_based_edge_cases():
random.seed(123)
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
# Single-point intervals.
for i in range(10):
point = random.randint(0, 100)
data = f"single_point_{i}_{point}"
tree.insert(point, point, data)
simple.insert(point, point, data)
assert is_balanced(
tree
), f"Tree unbalanced after inserting single point {point}"
# Test the exact point and neighbors
for test_point in [point - 1, point, point + 1]:
if 0 <= test_point <= 100:
tree_result = tree.find_smallest_interval(test_point)
simple_result = simple.find_smallest_interval(test_point)
assert tree_result == simple_result
def test_unique_intervals_override():
"""Test that inserting an interval with the same start/end overrides the previous value."""
tree = IntervalTree[str]()
tree.insert(10, 20, "original_value")
assert tree.find_smallest_interval(15) == "original_value"
tree.insert(10, 20, "new_value")
assert tree.find_smallest_interval(15) == "new_value"
tree.insert(10, 25, "different_interval")
results = tree.search(15)
assert "new_value" in results
assert "different_interval" in results
assert len(results) == 2
tree.insert(10, 20, "final_value")
assert tree.find_smallest_interval(15) == "final_value"
assert is_balanced(tree)

View File

@ -1,4 +1,3 @@
import sys
from contextlib import nullcontext
import pytest
@ -9,17 +8,11 @@ from torch.nn import functional as F
from kernels import (
Device,
LayerRepository,
Mode,
kernelize,
register_kernel_mapping,
use_kernel_forward_from_hub,
)
from kernels.layer import (
_KERNEL_MAPPING,
CUDAProperties,
_validate_layer,
use_kernel_mapping,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
kernel_layer_mapping = {
"SiluAndMul": {
@ -72,18 +65,6 @@ class SiluAndMulStringDevice(SiluAndMul):
pass
@use_kernel_forward_from_hub("Linear")
class TorchLinearWithCounter(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
@ -112,7 +93,7 @@ def test_hub_forward(cls, device):
X = torch.randn((32, 64), device=device)
Y = silu_and_mul(X)
silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
silu_and_mul_with_kernel = kernelize(cls(), device=device)
Y_kernel = silu_and_mul_with_kernel(X)
torch.testing.assert_close(Y_kernel, Y)
@ -124,55 +105,6 @@ def test_hub_forward(cls, device):
assert silu_and_mul_with_kernel.n_calls == 1
@pytest.mark.linux_only
def test_capability():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we called out to the kernel.
assert linear.n_calls == 0
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=sys.maxsize, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we didn't call out to the kernel because there
# is no kernel with a matching capability.
assert linear.n_calls == 1
def test_layer_fallback_works():
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
class SiluAndMulWithKernelFallback(SiluAndMul):
@ -180,7 +112,7 @@ def test_layer_fallback_works():
# Check that we don't raise an exception for a non-existing kernel.
silu_and_mul = SiluAndMulWithKernelFallback()
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)
kernelize(silu_and_mul, device="cuda")
@pytest.mark.linux_only
@ -196,7 +128,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul_with_kernel.eval()
ctx = (
pytest.raises(ValueError, match="does not support mode")
pytest.raises(ValueError, match="does not fulfill requirements")
if cls is SiluAndMulNoCompileKernel
else nullcontext()
)
@ -204,7 +136,7 @@ def test_torch_compile_layer_without_fallback(cls, device):
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
needs_torch_compile=True,
use_fallback=False,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
@ -228,7 +160,7 @@ def test_torch_compile_layer_with_fallback(cls, device):
silu_and_mul_with_kernel = kernelize(
silu_and_mul_with_kernel,
device=device,
mode=Mode.INFERENCE | Mode.TORCH_COMPILE,
needs_torch_compile=True,
)
silu_and_mul_compiled = torch.compile(silu_and_mul_with_kernel, fullgraph=True)
@ -237,7 +169,6 @@ def test_torch_compile_layer_with_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
@ -281,7 +212,7 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/non-existing"
)
@ -292,7 +223,7 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/activation"
)
@ -301,7 +232,7 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/non-existing"
)
@ -312,7 +243,7 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")].repo_id
== "kernels-community/activation"
)
@ -351,173 +282,20 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
@pytest.mark.linux_only
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
}
):
with pytest.raises(ValueError, match="does not support backward"):
kernelize(linear, mode=Mode.TRAINING)
@pytest.mark.linux_only
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
# Case 1: a layer registered without further specification becomes
# the base layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING)
linear(X)
assert linear.n_calls == 0
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
assert linear.n_calls == 0
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Training has a kernel, so no fallback.
assert linear.n_calls == 1
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 3
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.TRAINING: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
{
"Linear": {
"cuda": {
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 5
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 5
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 5
@pytest.mark.linux_only
def test_fallback_used_when_training():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@use_kernel_forward_from_hub("Linear")
class TorchLinear(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Used to check that we called hub kernel.
self.n_calls = 0
def forward(self, input: torch.Tensor) -> torch.Tensor:
self.n_calls += 1
return super().forward(input)
linear = TorchLinear(32, 32).to("cuda")
# Case 1: kernel with explicit backward support should always
# use the kernel.
@ -532,7 +310,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@ -554,7 +332,7 @@ def test_fallback_used_when_training():
}
):
linear.train()
kernelize(linear, mode=Mode.INFERENCE)
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 0
@ -563,18 +341,57 @@ def test_fallback_used_when_training():
linear(X)
assert linear.n_calls == 0
def test_invalid_mode_rejected():
with pytest.raises(ValueError, match="mutually exclusive"):
_ = Mode.INFERENCE | Mode.TRAINING
with pytest.raises(ValueError, match="cannot be combined with other modes"):
_ = Mode.DEFAULT | Mode.TORCH_COMPILE
with pytest.raises(
ValueError, match="can only be used to register kernel mappings"
# Case 3: a kernel without backward support should use the kernel in
# eval mode and the fallback in training. Test train ->
# eval -> train.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
):
kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
linear.train()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 1
with pytest.raises(ValueError, match="mode must contain"):
kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)
# When switching the kernel to eval, forward gets replaced by
# the kernel.
linear.eval()
linear(X)
assert linear.n_calls == 1
# Let's do it in the other direction to make sure it works as well.
linear.train()
linear(X)
assert linear.n_calls == 2
# Case 4: same as case 3, but test eval -> train -> eval.
with use_kernel_mapping(
{
"Linear": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearNoBackward",
)
}
}
):
linear.eval()
kernelize(linear)
X = torch.randn(10, 32, device="cuda")
linear(X)
assert linear.n_calls == 2
linear.train()
linear(X)
assert linear.n_calls == 3
linear.eval()
linear(X)
assert linear.n_calls == 3