Mirror of https://github.com/huggingface/kernels.git, synced 2025-10-21 21:38:52 +08:00

Compare commits (6 commits):

- 10a9686434
- 2d2c6b14e0
- 03edc573b1
- c841a6c90d
- c7a343f195
- 8d838f947d
docs/layers.md (151 changed lines)
@@ -49,23 +49,46 @@ A model will not use Hub kernels by default, even if it contains extensible
 layers. To enable the use of Hub kernels in the model, it needs to be
 'kernelized' using the `kernelize` function. This function traverses the
 model graph and replaces the `forward` methods of extensible layers for which
-Hub kernels are registered. Kernelize can be used as follows:
+Hub kernels are registered. `kernelize` can be used as follows:

 ```python
 model = MyModel(...)
 model = kernelize(model, mode=Mode.INFERENCE)
 ```

-The `mode` specifies that the model will be used in inference. Similarly,
-you can ask `kernelize` to prepare the model for training:
+The `kernelize` function modifies the model in-place, the model itself is
+returned as a convenience. The `mode` specifies that the model will be used
+in inference. Similarly, you can ask `kernelize` to prepare the model for
+training:

 ```python
 model = MyModel(...)
 model = kernelize(model, mode=Mode.TRAINING)
 ```

-**Note:** the `kernelize` function modifies the model in-place, the model
-itself is returned as a convenience.
 A model that is kernelized for training can also be used for inference, but
 not the other way around. If you want to change the mode of the kernelized
 model, you can just run `kernelize` on the model again with the new mode.

+If you want to compile a model with `torch.compile`, this should be indicated
+in the mode as well. You can do this by combining `Mode.INFERENCE` or
+`Mode.TRAINING` with `Mode.TORCH_COMPILE` using the set union (`|`) operator:
+
+```python
+model = MyModel(...)
+
+# Inference
+model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+
+# Training
+model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+```
+
+When the `mode` argument is not specified,
+`Mode.TRAINING | Mode.TORCH_COMPILE` is used as the default. This mode
+aligns most closely with pure PyTorch layers, which also support training
+and `torch.compile`. However, to select the most performant kernels, it
+is often good to make the mode as specific as possible.

 ### Kernel device
@@ -80,17 +103,6 @@ model = MyModel(...)
 model = kernelize(model, device="cuda", mode=Mode.INFERENCE)
 ```

-### `torch.compile`
-
-Not all Hub kernels support `torch.compile`. If you want to compile a model
-after kernelizing it, you need to add this to the mode. You can use the
-set union (`|`) operator to add `TORCH_COMPILE` to the mode:
-
-```python
-model = MyModel(...)
-model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
-```
-
 ### Fallback `forward`

 If the `TRAINING` and/or `TORCH_COMPILE` modes are used, but a registered
@@ -105,6 +117,12 @@ model = kernelize(model, mode=Mode.INFERENCE | Mode.TORCH_COMPILE, use_fallback=

 This can be useful if you want to guarantee that Hub kernels are used.

+### Inspecting which kernels are used
+
+The kernels that are used are logged at the `INFO` level by `kernelize`.
+See the [Python logging](https://docs.python.org/3/library/logging.html)
+documentation for information on how to configure logging.
+
 ## Registering a hub kernel for a layer

 `kernelize` relies on kernel mappings to find Hub kernels for layers.
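A minimal way to actually surface those logs when trying out this branch, using only the Python standard library (illustrative, not part of this diff):

```python
import logging

# Show the INFO-level kernel-selection messages emitted by `kernelize`;
# DEBUG additionally shows the requested and resolved modes.
logging.basicConfig(level=logging.INFO)
```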
@@ -164,34 +182,91 @@ kernel_layer_mapping = {
     }
 }
 ```

-The kernels will match exactly on the mode. So, for instance in the example above, no kernel
-layer is used when the `mode` passed to `kernelize` is
-`Mode.INFERENCE | Mode.TORCH_COMPILE` or `Mode.TRAINING`. However, if you want to
-register a kernel to be used when the mode does not match any of the
-modes in the mapping, you can use the special `Mode.DEFAULT` mode to do
-so. For example:
+The `kernelize` function will attempt to use the following registered
+kernels for a given mode:
+
+- `INFERENCE`: `INFERENCE` → `INFERENCE | TORCH_COMPILE` → `TRAINING` →
+  `TRAINING | TORCH_COMPILE` → `FALLBACK`
+- `INFERENCE | TORCH_COMPILE`: `INFERENCE | TORCH_COMPILE` →
+  `TRAINING | TORCH_COMPILE` → `FALLBACK`
+- `TRAINING`: `TRAINING` → `TRAINING | TORCH_COMPILE` → `FALLBACK`
+- `TRAINING | TORCH_COMPILE`: `TRAINING | TORCH_COMPILE` → `FALLBACK`
+
+`Mode.FALLBACK` is a special mode that is used when no other mode matches. It
+is also used when a kernel is registered without a mode, as described in the
+previous section.

 ```python
 kernel_layer_mapping = {
     "SiluAndMul": {
         "cuda": {
-            Mode.DEFAULT: LayerRepository(
-                repo_id="kernels-community/activation",
-                layer_name="SiluAndMul",
-            ),
-            Mode.INFERENCE: LayerRepository(
-                repo_id="kernels-community/activation-inference-optimized",
-                layer_name="SiluAndMul",
-            ),
-            Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
-                repo_id="kernels-community/activation-training-optimized",
-                layer_name="SiluAndMul",
-            ),
-        }
+            Mode.FALLBACK: LayerRepository(
+                repo_id="kernels-community/activation",
+                layer_name="SiluAndMul",
+            ),
+            Mode.INFERENCE: LayerRepository(
+                repo_id="kernels-community/activation-inference-optimized",
+                layer_name="SiluAndMul",
+            ),
+            Mode.TRAINING: LayerRepository(
+                repo_id="kernels-community/activation-training-optimized",
+                layer_name="SiluAndMul",
+            ),
+        }
     }
 }
 ```

-In this case, modes other than `Mode.INFERENCE` and
-`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
-`kernels-community/activation`.
+In this case, both `Mode.INFERENCE | Mode.TORCH_COMPILE` and
+`Mode.TRAINING | Mode.TORCH_COMPILE` will use the `Mode.FALLBACK` kernel,
+since the other kernels do not support `torch.compile`.
+
+### Registering kernels for specific CUDA capabilities
+
+Some kernels only work with newer CUDA architectures. For instance, some
+kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
+supports registering layers for a range of CUDA capabilities. To do so,
+you need to register the layer for a `Device` with type `cuda` and
+set the supported range of CUDA capabilities using `CUDAProperties`:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        Device(
+            type="cuda",
+            properties=CUDAProperties(
+                min_capability=75, max_capability=89
+            ),
+        ): LayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+        ),
+        Device(
+            type="cuda",
+            properties=CUDAProperties(
+                min_capability=90, max_capability=sys.maxsize
+            ),
+        ): LayerRepository(
+            repo_id="kernels-community/activation-hopper",
+            layer_name="SiluAndMul",
+        ),
+    }
+}
+```
+
+Capabilities behave as follows:
+
+- The minimum and maximum capabilities are inclusive.
+- When a new kernel is registered with the same min/max capabilities as
+  an existing kernel, the new kernel will replace the old kernel.
+- When there are multiple kernels that support a capability, the kernel
+  with the smaller capability interval will be used. E.g. given:
+
+  - `KernelA` with `min_capability=80` and `max_capability=89`;
+  - `KernelB` with `min_capability=75` and `max_capability=89`;
+  - `kernelize` runs on a system with capability 8.6.
+
+  Then `KernelA` will be used because the interval 80..89 is smaller
+  than 75..89. The motivation is that kernels with smaller ranges
+  tend to be more optimized for a specific set of GPUs. **This behavior
+  might still change in the future.**
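A note on the capability numbers used above: following `_find_capability` added in `src/kernels/layer.py` in this change, a compute capability `major.minor` is encoded as the integer `major * 10 + minor`, so `75` means 7.5 (Turing) and `90` means 9.0 (Hopper). A small sketch of the matching rule, with a hypothetical `encode_capability` helper:

```python
def encode_capability(major: int, minor: int) -> int:
    # Mirrors _find_capability: capability 9.0 -> 90.
    return major * 10 + minor

# Hopper (9.0) selects the activation-hopper entry above:
assert encode_capability(9, 0) == 90
# Capability 8.9 falls inside the inclusive 75..89 range:
assert 75 <= encode_capability(8, 9) <= 89
```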
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "kernels"
-version = "0.7.0.dev0"
+version = "0.8.0"
 description = "Download compute kernels"
 authors = [
     { name = "OlivierDehaene", email = "olivier@huggingface.co" },
src/kernels/__init__.py

@@ -1,4 +1,5 @@
 from kernels.layer import (
+    CUDAProperties,
     Device,
     LayerRepository,
     Mode,
@@ -18,6 +19,7 @@ from kernels.utils import (
 )

 __all__ = [
+    "CUDAProperties",
     "Device",
     "LayerRepository",
     "Mode",
src/kernels/_interval_tree.py (new file, 200 lines)

@@ -0,0 +1,200 @@
# AVL-balanced interval trees. We could use the intervaltree
# package, but it seems unmaintained and does not have type
# annotations.

from typing import Generic, List, Optional, Tuple, TypeVar

T = TypeVar("T")


class _Node(Generic[T]):
    """A node in the interval tree."""

    def __init__(self, start: int, end: int, data: T):
        self.start: int = start
        self.end: int = end
        self.data: T = data
        self.max_end: int = end
        self.left: Optional["_Node[T]"] = None
        self.right: Optional["_Node[T]"] = None
        self.height: int = 1

    def __repr__(self) -> str:
        return f"Node({self.start}, {self.end})"


class IntervalTree(Generic[T]):
    """A data structure to hold and query (unique) intervals."""

    root: Optional[_Node[T]]

    def __init__(self):
        self.root = None

    def insert(self, start: int, end: int, data: T) -> None:
        """
        Inserts a new interval into the tree.

        Args:
            start: The starting point of the interval.
            end: The ending point of the interval.
            data: The data associated with this interval.
        """
        self.root = self._insert(self.root, start, end, data)

    def _get_height(self, node: Optional[_Node[T]]) -> int:
        if not node:
            return 0
        return node.height

    def _get_balance(self, node: Optional[_Node[T]]) -> int:
        if not node:
            return 0
        return self._get_height(node.left) - self._get_height(node.right)

    def _update_node_attributes(self, node: _Node[T]) -> None:
        node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
        node.max_end = node.end
        if node.left:
            node.max_end = max(node.max_end, node.left.max_end)
        if node.right:
            node.max_end = max(node.max_end, node.right.max_end)

    def _right_rotate(self, y: _Node[T]) -> _Node[T]:
        """Performs a right rotation."""
        x = y.left
        assert x is not None
        T2 = x.right

        x.right = y
        y.left = T2

        self._update_node_attributes(y)
        self._update_node_attributes(x)

        return x

    def _left_rotate(self, x: _Node[T]) -> _Node[T]:
        """Performs a left rotation."""
        y = x.right
        assert y is not None
        T2 = y.left

        y.left = x
        x.right = T2

        self._update_node_attributes(x)
        self._update_node_attributes(y)

        return y

    def _insert(
        self, node: Optional[_Node[T]], start: int, end: int, data: T
    ) -> _Node[T]:
        """Recursive helper to insert a new node and balance the tree."""
        if not node:
            return _Node(start, end, data)

        # Replace the data if the interval already exists.
        if start == node.start and end == node.end:
            node.data = data
            return node

        if start < node.start:
            node.left = self._insert(node.left, start, end, data)
        else:
            node.right = self._insert(node.right, start, end, data)

        self._update_node_attributes(node)

        balance = self._get_balance(node)

        # Left Left Case
        if balance > 1 and node.left and start < node.left.start:
            return self._right_rotate(node)

        # Right Right Case
        if balance < -1 and node.right and start >= node.right.start:
            return self._left_rotate(node)

        # Left Right Case
        if balance > 1 and node.left and start >= node.left.start:
            node.left = self._left_rotate(node.left)
            return self._right_rotate(node)

        # Right Left Case
        if balance < -1 and node.right and start < node.right.start:
            node.right = self._right_rotate(node.right)
            return self._left_rotate(node)

        return node

    def search(self, point: int) -> List[T]:
        """
        Searches for all intervals that contain the given point.

        Args:
            point: The point to search for.

        Returns:
            A list of data items from all matching intervals.
        """
        results: List[T] = []
        self._search(self.root, point, results)
        return results

    def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
        """Recursive helper to find all overlapping intervals."""
        if node is None or point > node.max_end:
            return

        if node.left:
            self._search(node.left, point, results)

        if node.start <= point <= node.end:
            results.append(node.data)

        if point >= node.start and node.right:
            self._search(node.right, point, results)

    def find_smallest_interval(self, point: int) -> Optional[T]:
        """
        Finds the item with the most specific (smallest) range for a given point.

        Args:
            point: The capability to look up.

        Returns:
            The data of the best-matching item, or None if no match is found.
        """
        matches: List[Tuple[int, int, T]] = []
        self._find_with_intervals(self.root, point, matches)

        if not matches:
            return None

        # Return the smallest interval, sort by memory location when
        # there are multiple matches with the same interval size. This
        # is just to ensure that we can compare against a trivial
        # implementation in tests.
        best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
        return best_match[2]

    def _find_with_intervals(
        self,
        node: Optional[_Node[T]],
        point: int,
        results: List[Tuple[int, int, T]],
    ) -> None:
        """A modified search that collects interval ranges along with data."""
        if node is None or point > node.max_end:
            return

        if node.left:
            self._find_with_intervals(node.left, point, results)

        if node.start <= point <= node.end:
            results.append((node.start, node.end, node.data))

        if point >= node.start and node.right:
            self._find_with_intervals(node.right, point, results)
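To make the new data structure's semantics concrete, here is a minimal usage sketch (the string labels are made up; bounds are inclusive and the smallest containing interval wins, mirroring the capability rules documented above):

```python
from kernels._interval_tree import IntervalTree

tree = IntervalTree[str]()
tree.insert(75, 89, "generic")  # Turing through Ada
tree.insert(80, 89, "ampere")   # narrower interval nested inside the first
tree.insert(90, 120, "hopper")

# All intervals containing a point:
assert set(tree.search(86)) == {"generic", "ampere"}
# The most specific (smallest) interval is preferred:
assert tree.find_smallest_interval(86) == "ampere"
assert tree.find_smallest_interval(90) == "hopper"
assert tree.find_smallest_interval(70) is None
```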
src/kernels/layer.py

@@ -1,12 +1,16 @@
 from __future__ import annotations

 import inspect
+import logging
 import os
+import sys
 import warnings
+from abc import ABC, abstractmethod
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Flag, auto
+from functools import lru_cache
 from types import MethodType
 from typing import (
     TYPE_CHECKING,
@@ -17,6 +21,7 @@ from typing import (
     Union,
 )

+from ._interval_tree import IntervalTree
 from .utils import get_kernel

 if TYPE_CHECKING:
@@ -37,7 +42,7 @@ class Mode(Flag):
     * `INFERENCE`: The kernel is used for inference.
     * `TRAINING`: The kernel is used for training.
     * `TORCH_COMPILE`: The kernel is used with `torch.compile`.
-    * `DEFAULT`: In a kernel mapping, this kernel is used when no other mode
+    * `FALLBACK`: In a kernel mapping, this kernel is used when no other mode
       matches.

     Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE`
@@ -45,7 +50,7 @@ class Mode(Flag):
     """

     _NONE = 0
-    DEFAULT = auto()
+    FALLBACK = auto()
     TRAINING = auto()
     INFERENCE = auto()
     TORCH_COMPILE = auto()
@@ -56,8 +61,8 @@ class Mode(Flag):
         if Mode.INFERENCE in union and Mode.TRAINING in union:
             raise ValueError("Mode.INFERENCE and Mode.TRAINING are mutually exclusive.")

-        if Mode.DEFAULT in union and union != Mode.DEFAULT:
-            raise ValueError("Mode.DEFAULT cannot be combined with other modes.")
+        if Mode.FALLBACK in union and union != Mode.FALLBACK:
+            raise ValueError("Mode.FALLBACK cannot be combined with other modes.")

         return union

@@ -65,14 +70,46 @@ class Mode(Flag):
 @dataclass(frozen=True)
 class Device:
     type: str
+    properties: Optional[CUDAProperties] = None

-    # In the future we might add compute capabilities, etc.
+    def __post_init__(self):
+        if self.properties is not None and isinstance(self.properties, CUDAProperties):
+            if self.type != "cuda":
+                raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
+
+    def create_repo(self) -> _DeviceRepos:
+        """Create an appropriate repository set for this device type."""
+        if self.type == "cuda":
+            return _CUDARepos()
+        elif self.type == "mps":
+            return _MPSRepos()
+        else:
+            raise ValueError(f"Unknown device type: {self.type}")

     def __eq__(self, other):
-        return isinstance(other, Device) and self.type == other.type
+        if not isinstance(other, Device):
+            return NotImplemented
+        return self.type == other.type and self.properties == other.properties

     def __hash__(self):
-        return hash(self.type)
+        return hash((self.type, self.properties))
+
+
+@dataclass(frozen=True)
+class CUDAProperties:
+    min_capability: int
+    max_capability: int
+
+    def __eq__(self, other):
+        if not isinstance(other, CUDAProperties):
+            return NotImplemented
+        return (
+            self.min_capability == other.min_capability
+            and self.max_capability == other.max_capability
+        )
+
+    def __hash__(self):
+        return hash((self.min_capability, self.max_capability))


 @dataclass
@@ -104,8 +141,78 @@ class LayerRepository:
 _CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}


-_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
-    ContextVar("_KERNEL_MAPPING", default={})
-)
+class _DeviceRepos(ABC):
+    """
+    Device-specific kernel layer repositories.
+    """
+
+    @property
+    @abstractmethod
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepository]]: ...
+
+    @abstractmethod
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+        """
+        Insert a repository for a specific device and mode.
+        """
+        ...
+
+
+class _MPSRepos(_DeviceRepos):
+    _repos: Dict[Mode, LayerRepository]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepository]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+        if device.type != "mps":
+            raise ValueError(f"Device type must be 'mps', got {device.type}")
+
+        self._repos = repos
+
+
+class _CUDARepos(_DeviceRepos):
+    _repos: IntervalTree[Dict[Mode, LayerRepository]]
+
+    def __init__(self):
+        super().__init__()
+        self.repos_by_capability = IntervalTree()
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepository]]:
+        capability = _find_capability()
+        return self.repos_by_capability.find_smallest_interval(capability)
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+        assert device.properties is None or isinstance(
+            device.properties, CUDAProperties
+        )
+
+        min_capability = (
+            0 if device.properties is None else device.properties.min_capability
+        )
+        max_capability = (
+            sys.maxsize
+            if device.properties is None
+            else device.properties.max_capability
+        )
+
+        self.repos_by_capability.insert(min_capability, max_capability, repos)
+
+
+_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
+    "_KERNEL_MAPPING", default={}
+)

@@ -177,11 +284,12 @@ def register_kernel_mapping(
             )

             if isinstance(new_repo, LayerRepository):
-                kernel_options = {Mode.DEFAULT: new_repo}
+                kernel_options = {Mode.FALLBACK: new_repo}
             else:
                 kernel_options = new_repo

-            device_repo[device] = kernel_options
+            feature_repos = device_repo.setdefault(device.type, device.create_repo())
+            feature_repos.insert(device, kernel_options)


 def replace_kernel_forward_from_hub(
@@ -202,23 +310,56 @@ def replace_kernel_forward_from_hub(
     cls.kernel_layer_name = layer_name


+_MODE_FALLBACK_PRIORITY = {
+    Mode.INFERENCE: [
+        Mode.INFERENCE,
+        Mode.INFERENCE | Mode.TORCH_COMPILE,
+        Mode.TRAINING,
+        Mode.TRAINING | Mode.TORCH_COMPILE,
+        Mode.FALLBACK,
+    ],
+    Mode.TRAINING: [
+        Mode.TRAINING,
+        Mode.TRAINING | Mode.TORCH_COMPILE,
+        Mode.FALLBACK,
+    ],
+    Mode.INFERENCE
+    | Mode.TORCH_COMPILE: [
+        Mode.INFERENCE | Mode.TORCH_COMPILE,
+        Mode.TRAINING | Mode.TORCH_COMPILE,
+        Mode.FALLBACK,
+    ],
+    Mode.TRAINING
+    | Mode.TORCH_COMPILE: [
+        Mode.TRAINING | Mode.TORCH_COMPILE,
+        Mode.FALLBACK,
+    ],
+}
+
+
 def _select_repository(
     repositories: Dict[Mode, LayerRepository],
     *,
     mode: Mode,
 ) -> Optional[Tuple[LayerRepository, Mode]]:
-    if mode in repositories:
-        return (repositories[mode], mode)
-    elif Mode.DEFAULT in repositories:
-        return (repositories[Mode.DEFAULT], Mode.DEFAULT)
-    else:
-        return None
+    # Get the fallback priority list for the requested mode
+    if mode not in _MODE_FALLBACK_PRIORITY:
+        raise ValueError(f"Unsupported mode: {mode}")
+
+    fallback_modes = _MODE_FALLBACK_PRIORITY[mode]
+
+    # Try each mode in priority order
+    for fallback_mode in fallback_modes:
+        if fallback_mode in repositories:
+            return (repositories[fallback_mode], fallback_mode)
+
+    return None


 def kernelize(
     model: "nn.Module",
     *,
-    mode: Mode,
+    mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
     device: Optional[Union[str, "torch.device"]] = None,
     use_fallback: bool = True,
 ):
@@ -243,8 +384,8 @@ def kernelize(
     """
     import torch

-    if mode == Mode.DEFAULT:
-        raise ValueError("Mode.DEFAULT can only be used to register kernel mappings.")
+    if mode == Mode.FALLBACK:
+        raise ValueError("Mode.FALLBACK can only be used to register kernel mappings.")

     # Type check ignored because this causes a false negative on Python < 3.11.
     # Looks similar to: https://github.com/python/mypy/issues/9642
@@ -258,6 +399,7 @@ def kernelize(
         device_type = Device(type=torch.device(device).type)
     else:
         device_type = Device(device.type)

+    assert isinstance(device_type, Device)

     for _, module in model.named_modules():
@@ -285,12 +427,22 @@ def kernelize(
             continue

         # Get kernel options for the device
-        repos = kernel.get(device_type)
+        property_repos = kernel.get(device_type.type)
+
+        if property_repos is None:
+            if not use_fallback:
+                raise ValueError(
+                    f"No layer mapping for `{layer_name}` with device type `{device_type}`"
+                )
+            _replace_forward(module, module_class)
+            continue
+
+        repos = property_repos.repos

         if repos is None:
             if not use_fallback:
                 raise ValueError(
-                    f"No layer mapping for `{layer_name}` with device type `{device_type}`"
+                    f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
                 )
             _replace_forward(module, module_class)
             continue
@@ -310,6 +462,11 @@ def kernelize(

         repo, repo_mode = repo_with_mode

+        logging.info(
+            f"Using layer `{repo.layer_name}` from repo `{repo.repo_id}` (revision: {repo.revision}) for layer `{layer_name}`"
+        )
+        logging.debug(f"kernelize mode: {mode}, repo mode: {repo_mode}")
+
         layer = _get_layer_memoize(repo, module_class)

         # Ideally we would do validation on the mapping where we check that
@@ -411,6 +568,14 @@ def _find_device(model: "nn.Module") -> Device:
     return Device(type=param.device.type)


+@lru_cache
+def _find_capability() -> int:
+    import torch
+
+    major, minor = torch.cuda.get_device_capability(device=None)
+    return major * 10 + minor
+
+
 def _conditionally_replace_forward(
     *,
     module: "nn.Module",
@@ -422,7 +587,7 @@ def _conditionally_replace_forward(

     # Switch to fallback if the mode is not supported by the layer.
     # Note that this is useful even after _validate_layer_has_mode because
-    # layers registered with the DEFAULT mode never get rejected by
+    # layers registered with the FALLBACK mode never get rejected by
     # _validate_layer_has_mode. For such layers, we want to fall back in
     # case the layer does not support the given mode.
     needs_fallback = Mode.TORCH_COMPILE in mode and not getattr(
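The core behavioral change in `layer.py` is that `_select_repository` now walks a per-mode priority list instead of doing an exact dictionary lookup. A small sketch of what that buys you, using only names introduced in this change (importing the private `_MODE_FALLBACK_PRIORITY` table is for illustration only):

```python
from kernels.layer import _MODE_FALLBACK_PRIORITY, Mode

# A compiled inference mode only ever resolves to kernels that also
# support torch.compile, then to the FALLBACK entry; it never silently
# picks a plain INFERENCE or TRAINING kernel.
assert _MODE_FALLBACK_PRIORITY[Mode.INFERENCE | Mode.TORCH_COMPILE] == [
    Mode.INFERENCE | Mode.TORCH_COMPILE,
    Mode.TRAINING | Mode.TORCH_COMPILE,
    Mode.FALLBACK,
]

# The Flag-based containment checks used by _conditionally_replace_forward:
mode = Mode.INFERENCE | Mode.TORCH_COMPILE
assert Mode.TORCH_COMPILE in mode and Mode.TRAINING not in mode
```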
tests/test_interval_tree.py (new file, 230 lines)

@@ -0,0 +1,230 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar

import pytest

from kernels._interval_tree import IntervalTree, _Node

T = TypeVar("T")


class SimpleIntervalStore(Generic[T]):
    """A simple O(n) implementation that stores intervals in a list."""

    def __init__(self):
        self.intervals: List[Tuple[int, int, T]] = []

    def insert(self, start: int, end: int, data: T) -> None:
        """Insert an interval into the store."""
        # Replace data if the interval already exists.
        for i, (existing_start, existing_end, existing_data) in enumerate(
            self.intervals
        ):
            if existing_start == start and existing_end == end:
                self.intervals[i] = (start, end, data)
                return

        self.intervals.append((start, end, data))

    def find_smallest_interval(self, point: int) -> Optional[T]:
        """Find the best match using linear search."""
        matches = []
        for start, end, data in self.intervals:
            if start <= point <= end:
                matches.append((start, end, data))

        if not matches:
            return None

        # Return the smallest interval, sort by memory location when
        # there are multiple matches with the same interval size. This
        # mirrors the ordering in the interval tree.
        best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
        return best_match[2]


def is_balanced(tree: IntervalTree[T]) -> bool:
    """Check if the AVL tree is properly balanced."""

    def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
        if node is None:
            return True, 0

        # Left and right subtrees should be balanced.
        left_balanced, left_height = check_balance(node.left)
        if not left_balanced:
            return False, -1

        right_balanced, right_height = check_balance(node.right)
        if not right_balanced:
            return False, -1

        # The difference in height should not exceed 1.
        if abs(left_height - right_height) > 1:
            return False, -1

        # Check if the height is correct.
        expected_height = 1 + max(left_height, right_height)
        if node.height != expected_height:
            return False, -1

        return True, expected_height

    balanced, _ = check_balance(tree.root)
    return balanced


@pytest.fixture
def populated_tree() -> IntervalTree[str]:
    """Provides a pre-populated IntervalTree for testing."""
    tree = IntervalTree[str]()
    kernels = [
        (80, 89, "Kernel_A_General_80_89"),
        (86, 89, "Kernel_B_Ampere_86_89"),
        (80, 86, "Kernel_C_Older_Ampere_80_86"),
        (70, 75, "Kernel_D_Volta_70_75"),
        (86, 87, "Kernel_E_Specific_86_87"),
    ]
    for start, end, name in kernels:
        tree.insert(start, end, name)
    return tree


def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
    # Check that the smallest interval is selected when there are
    # multiple matching intervals.
    assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"


def test_find_single_match(populated_tree):
    assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
    assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"


def test_no_match_outside_all_ranges(populated_tree):
    # Check that no interval is found when the value is out of range
    # (too small/too large).
    assert populated_tree.find_smallest_interval(65) is None
    assert populated_tree.find_smallest_interval(95) is None


def test_no_match_in_gap_between_ranges(populated_tree):
    # Check that no interval is found when the value is between two
    # intervals.
    assert populated_tree.find_smallest_interval(78) is None


def test_boundary_conditions_start_and_end(populated_tree):
    # Test exact upper/lower bounds of intervals.
    assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
    assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"


def test_empty_tree():
    # Searching in an empty tree should return None.
    empty_tree = IntervalTree[str]()
    assert empty_tree.find_smallest_interval(100) is None


def test_multiple_equally_specific_matches():
    # Check that we pick the match in a stable way when there are
    # multiple matching intervals with the same size.
    tree = IntervalTree[str]()
    str1 = "First_Narrow_Kernel"
    str2 = "Second_Narrow_Kernel"
    tree.insert(10, 20, "Wide_Kernel")
    tree.insert(12, 17, str1)
    tree.insert(14, 19, str2)

    if id(str1) < id(str2):
        assert tree.find_smallest_interval(15) == str1
    else:
        assert tree.find_smallest_interval(15) == str2


def test_property_based_interval_tree():
    # QuickCheck-style property-based testing:
    #
    # - Verify that the tree is balanced after each insertion.
    # - Verify the query against a simple list-based implementation.

    random.seed(42)  # For reproducible tests

    test_points = list(range(0, 101))

    for _ in range(5):
        tree = IntervalTree[str]()
        simple = SimpleIntervalStore[str]()

        intervals = []
        for i in range(100):
            start = random.randint(0, 90)
            end = random.randint(start, 100)
            data = f"interval_{i}_s{start}_e{end}"
            intervals.append((start, end, data))

        for i, (start, end, data) in enumerate(intervals):
            tree.insert(start, end, data)
            simple.insert(start, end, data)

            # Check that tree is still balanced
            assert is_balanced(
                tree
            ), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"

            for point in test_points:
                tree_result = tree.find_smallest_interval(point)
                simple_result = simple.find_smallest_interval(point)

                assert tree_result == simple_result, (
                    f"Mismatch for point {point} after inserting {i+1} intervals. "
                    f"Tree: {tree_result}, Simple: {simple_result}. "
                    f"Last inserted: ({start}, {end})"
                )


def test_property_based_edge_cases():
    random.seed(123)

    tree = IntervalTree[str]()
    simple = SimpleIntervalStore[str]()

    # Single-point intervals.
    for i in range(10):
        point = random.randint(0, 100)
        data = f"single_point_{i}_{point}"
        tree.insert(point, point, data)
        simple.insert(point, point, data)

        assert is_balanced(
            tree
        ), f"Tree unbalanced after inserting single point {point}"

        # Test the exact point and neighbors
        for test_point in [point - 1, point, point + 1]:
            if 0 <= test_point <= 100:
                tree_result = tree.find_smallest_interval(test_point)
                simple_result = simple.find_smallest_interval(test_point)
                assert tree_result == simple_result


def test_unique_intervals_override():
    """Test that inserting an interval with the same start/end overrides the previous value."""
    tree = IntervalTree[str]()

    tree.insert(10, 20, "original_value")
    assert tree.find_smallest_interval(15) == "original_value"

    tree.insert(10, 20, "new_value")
    assert tree.find_smallest_interval(15) == "new_value"

    tree.insert(10, 25, "different_interval")
    results = tree.search(15)
    assert "new_value" in results
    assert "different_interval" in results
    assert len(results) == 2

    tree.insert(10, 20, "final_value")
    assert tree.find_smallest_interval(15) == "final_value"

    assert is_balanced(tree)
tests/test_layers.py

@@ -1,3 +1,4 @@
+import sys
 from contextlib import nullcontext

 import pytest
@@ -13,7 +14,12 @@ from kernels import (
     register_kernel_mapping,
     use_kernel_forward_from_hub,
 )
-from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
+from kernels.layer import (
+    _KERNEL_MAPPING,
+    CUDAProperties,
+    _validate_layer,
+    use_kernel_mapping,
+)

 kernel_layer_mapping = {
     "SiluAndMul": {
@@ -118,6 +124,55 @@ def test_hub_forward(cls, device):
     assert silu_and_mul_with_kernel.n_calls == 1


+@pytest.mark.linux_only
+def test_capability():
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(
+                    type="cuda",
+                    properties=CUDAProperties(
+                        min_capability=75, max_capability=sys.maxsize
+                    ),
+                ): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+
+        # Check that we called out to the kernel.
+        assert linear.n_calls == 0
+
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                Device(
+                    type="cuda",
+                    properties=CUDAProperties(
+                        min_capability=sys.maxsize, max_capability=sys.maxsize
+                    ),
+                ): LayerRepository(
+                    repo_id="kernels-test/backward-marker-test",
+                    layer_name="LinearBackward",
+                )
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+
+        # Check that we didn't call out to the kernel because there
+        # is no kernel with a matching capability.
+        assert linear.n_calls == 1
+
+
 def test_layer_fallback_works():
     @use_kernel_forward_from_hub("SiluAndMulNonExisting")
     class SiluAndMulWithKernelFallback(SiluAndMul):
@@ -182,6 +237,7 @@ def test_torch_compile_layer_with_fallback(cls, device):
     torch.testing.assert_close(Y_compiled, Y)


+@pytest.mark.linux_only
 def test_mapping_contexts():
     assert set(_KERNEL_MAPPING.get().keys()) == {
         "SiluAndMul",
@@ -225,9 +281,7 @@ def test_mapping_contexts():
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
-            Mode.DEFAULT
-        ].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
         == "kernels-community/non-existing"
     )

@@ -238,9 +292,7 @@ def test_mapping_contexts():
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
-            Mode.DEFAULT
-        ].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
        == "kernels-community/activation"
     )

@@ -249,9 +301,7 @@ def test_mapping_contexts():
         "SiluAndMul",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
-            Mode.DEFAULT
-        ].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
         == "kernels-community/non-existing"
     )

@@ -262,9 +312,7 @@ def test_mapping_contexts():
         "TestKernel",
     }
     assert (
-        _KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
-            Mode.DEFAULT
-        ].repo_id
+        _KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.FALLBACK].repo_id
         == "kernels-community/activation"
     )
@@ -303,6 +351,7 @@ def test_validate_kernel_layer():
     _validate_layer(cls=BadLayer4, check_cls=SiluAndMul)


+@pytest.mark.linux_only
 def test_invalid_mode_for_mapping_rejected():
     linear = TorchLinearWithCounter(32, 32).to("cuda")

@@ -322,6 +371,7 @@ def test_invalid_mode_for_mapping_rejected():
         kernelize(linear, mode=Mode.TRAINING)


+@pytest.mark.linux_only
 def test_kernel_modes():
     linear = TorchLinearWithCounter(32, 32).to("cuda")

@@ -350,6 +400,11 @@ def test_kernel_modes():
         linear(X)
         assert linear.n_calls == 0

+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        assert linear.n_calls == 0
+
     # Case 2: register a kernel just for training. If no base kernel
     # layer is registered, we fall back to the original layer.
     with use_kernel_mapping(
@@ -367,16 +422,22 @@ def test_kernel_modes():
         kernelize(linear, mode=Mode.INFERENCE)
         X = torch.randn(10, 32, device="cuda")
         linear(X)
-        assert linear.n_calls == 1
+        assert linear.n_calls == 0

         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
-        # Training has a kernel, so fallback.
-        assert linear.n_calls == 1
+        assert linear.n_calls == 0

         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
-        # No kernel for training + torch.compile, so fallback.
+        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
         assert linear.n_calls == 1

+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original.
+        assert linear.n_calls == 2
+
     # Case 3: register a kernel just for training and one for fallback.
@@ -384,7 +445,7 @@ def test_kernel_modes():
         {
             "Linear": {
                 "cuda": {
-                    Mode.DEFAULT: LayerRepository(
+                    Mode.FALLBACK: LayerRepository(
                         repo_id="kernels-test/backward-marker-test",
                         layer_name="LinearBackward",
                     ),
@@ -399,17 +460,23 @@ def test_kernel_modes():
         kernelize(linear, mode=Mode.INFERENCE)
         X = torch.randn(10, 32, device="cuda")
         linear(X)
-        # Uses the base kernel.
+        # Falls back to TRAINING.
         assert linear.n_calls == 2

         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
-        # Uses the training kernel.
+        # Falls back to the TRAINING kernel.
         assert linear.n_calls == 2

         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
-        # Uses the base kernel.
+        # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
         assert linear.n_calls == 2

+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        kernelize(linear)
+        linear(X)
+        # TRAINING | TORCH_COMPILE falls back to FALLBACK kernel.
+        assert linear.n_calls == 2
+
     # Case 4: register a kernel with two preferences.
@@ -429,18 +496,23 @@ def test_kernel_modes():
         kernelize(linear, mode=Mode.INFERENCE)
         X = torch.randn(10, 32, device="cuda")
         linear(X)
-        # No inference kernel, so fallback.
-        assert linear.n_calls == 3
+        # Falls back to the TRAINING | TORCH_COMPILE kernel.
+        assert linear.n_calls == 2

         kernelize(linear, mode=Mode.TRAINING)
         linear(X)
-        # No training kernel, so fallback.
-        assert linear.n_calls == 4
+        # TRAINING can fall back to TRAINING | TORCH_COMPILE kernel.
+        assert linear.n_calls == 2

         kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
         linear(X)
-        # We do have a training + torch.compile kernel.
-        assert linear.n_calls == 4
+        # Uses TRAINING | TORCH_COMPILE kernel.
+        assert linear.n_calls == 2

+        kernelize(linear)
+        linear(X)
+        # Same as previous, since TRAINING | TORCH_COMPILE is the default.
+        assert linear.n_calls == 2
+

 @pytest.mark.linux_only
@@ -497,12 +569,291 @@ def test_invalid_mode_rejected():
         _ = Mode.INFERENCE | Mode.TRAINING

     with pytest.raises(ValueError, match="cannot be combined with other modes"):
-        _ = Mode.DEFAULT | Mode.TORCH_COMPILE
+        _ = Mode.FALLBACK | Mode.TORCH_COMPILE

     with pytest.raises(
         ValueError, match="can only be used to register kernel mappings"
     ):
-        kernelize(torch.nn.Linear(32, 32), mode=Mode.DEFAULT)
+        kernelize(torch.nn.Linear(32, 32), mode=Mode.FALLBACK)

     with pytest.raises(ValueError, match="mode must contain"):
         kernelize(torch.nn.Linear(32, 32), mode=Mode.TORCH_COMPILE)


+@pytest.mark.linux_only
+def test_kernel_modes_inference():
+    """Test inference-specific fallback scenarios."""
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+
+    # Case 1: register a kernel just for inference
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        linear(X)
+        # INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
+        assert linear.n_calls == 1
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # No training kernel, so fallback to original
+        assert linear.n_calls == 2
+
+    # Case 2: register a kernel just for inference + torch.compile
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE
+                    | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.INFERENCE)
+        linear(X)
+        # INFERENCE falls back to INFERENCE | TORCH_COMPILE kernel
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # No training kernel, so fallback to original
+        assert linear.n_calls == 3
+
+    # Case 3: register both inference kernels
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.INFERENCE
+                    | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # Uses exact INFERENCE kernel
+        assert linear.n_calls == 3
+
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        linear(X)
+        # Uses exact INFERENCE | TORCH_COMPILE kernel
+        assert linear.n_calls == 3
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # No training kernel, so fallback to original
+        assert linear.n_calls == 4
+
+
+@pytest.mark.linux_only
+def test_kernel_modes_mixed():
+    """Test mixed training and inference kernel scenarios."""
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+
+    # Case 1: register both base inference and training kernels
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        linear(X)
+        # INFERENCE | TORCH_COMPILE cannot fall back to INFERENCE kernel, so uses original
+        assert linear.n_calls == 1
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # TRAINING | TORCH_COMPILE cannot fall back to TRAINING kernel, so uses original
+        assert linear.n_calls == 2
+
+    # Case 2: register all four kernel modes
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.INFERENCE
+                    | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.TRAINING
+                    | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # Uses exact INFERENCE kernel
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # Uses exact TRAINING kernel
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        linear(X)
+        # Uses exact INFERENCE | TORCH_COMPILE kernel
+        assert linear.n_calls == 2
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # Uses exact TRAINING | TORCH_COMPILE kernel
+        assert linear.n_calls == 2
+
+
+@pytest.mark.linux_only
+def test_kernel_modes_cross_fallback():
+    """Test cross-mode fallback scenarios from inference to training modes."""
+    linear = TorchLinearWithCounter(32, 32).to("cuda")
+
+    # Case 1: Only training kernel registered - inference should fall back to training
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.TRAINING: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # INFERENCE falls back to TRAINING kernel
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # TRAINING uses the kernel directly
+        assert linear.n_calls == 0
+
+    # Case 2: Only training + torch.compile kernel registered
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    )
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.INFERENCE)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # INFERENCE falls back to TRAINING | TORCH_COMPILE kernel
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.INFERENCE | Mode.TORCH_COMPILE)
+        linear(X)
+        # INFERENCE | TORCH_COMPILE falls back to TRAINING | TORCH_COMPILE kernel
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING)
+        linear(X)
+        # TRAINING falls back to TRAINING | TORCH_COMPILE kernel
+        assert linear.n_calls == 0
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # TRAINING | TORCH_COMPILE uses the kernel directly
+        assert linear.n_calls == 0
+
+    # Case 3: Test that training modes don't fall back to inference modes
+    with use_kernel_mapping(
+        {
+            "Linear": {
+                "cuda": {
+                    Mode.INFERENCE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                    Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
+                        repo_id="kernels-test/backward-marker-test",
+                        layer_name="LinearBackward",
+                    ),
+                }
+            }
+        }
+    ):
+        kernelize(linear, mode=Mode.TRAINING)
+        X = torch.randn(10, 32, device="cuda")
+        linear(X)
+        # TRAINING should NOT fall back to inference kernels, use original
+        assert linear.n_calls == 1
+
+        kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
+        linear(X)
+        # TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
+        assert linear.n_calls == 2
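Taken together, an end-to-end sketch of the API surface this change lands (the model class is hypothetical; repo ids are borrowed from the docs example above):

```python
import sys

from kernels import (
    CUDAProperties,
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
)

# Map the extensible `SiluAndMul` layer to a Hopper-only kernel; the
# FALLBACK entry applies whenever no more specific mode is registered.
register_kernel_mapping(
    {
        "SiluAndMul": {
            Device(
                type="cuda",
                properties=CUDAProperties(min_capability=90, max_capability=sys.maxsize),
            ): {
                Mode.FALLBACK: LayerRepository(
                    repo_id="kernels-community/activation-hopper",
                    layer_name="SiluAndMul",
                ),
            },
        }
    }
)

model = MyModel(...)  # hypothetical model containing extensible layers
# With no explicit mode, kernelize defaults to Mode.TRAINING | Mode.TORCH_COMPILE.
model = kernelize(model)
```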