Compare commits

..

1 Commits

Author SHA1 Message Date
b7b5f40143 Set version to 0.7.0 2025-07-07 13:09:01 +00:00
7 changed files with 45 additions and 756 deletions

View File

@ -49,16 +49,13 @@ A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. `kernelize` can be used as follows:
Hub kernels are registered. Kernelize can be used as follows:
```python
model = MyModel(...)
model = kernelize(model, mode=Mode.INFERENCE)
```
The `kernelize` function modifies the model in-place, the model
itself is returned as a convenience.
The `mode` specifies that the model will be used in inference. Similarly,
you can ask `kernelize` to prepare the model for training:
@ -67,11 +64,8 @@ model = MyModel(...)
model = kernelize(model, mode=Mode.TRAINING)
```
When the `mode` argument is not specified, the
`Mode.TRAINING | Mode.TORCH_COMPILE` mode is used as the default. This mode
aligns most closely with pure PyTorch layers (which generally support backward
passes and `torch.compile`). However, this mode can also lead to fewer
kernels being used, since not all kernels support training or `torch.compile`.
**Note:** the `kernelize` function modifies the model in-place, the model
itself is returned as a convenience.
### Kernel device
@ -181,19 +175,19 @@ so. For example:
kernel_layer_mapping = {
"SiluAndMul": {
"cuda": {
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
Mode.DEFAULT: LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Mode.INFERENCE: LayerRepository(
repo_id="kernels-community/activation-inference-optimized",
layer_name="SiluAndMul",
),
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-community/activation-training-optimized",
layer_name="SiluAndMul",
),
}
}
}
```
@ -201,78 +195,3 @@ kernel_layer_mapping = {
In this case, modes other than `Mode.INFERENCE` and
`Mode.TRAINING | Mode.TORCH_COMPILE` will be kernelized using
`kernels-community/activation`.
### Mode fallback behavior
As described above, if there is no exact match for the mode given to
`kernelize`, it will try to use the kernel registered for `Mode.DEFAULT`.
If the `Mode.DEFAULT` kernel does not support the `kernelize` mode, the
original layer's `forward` method will be used instead.
As an example, suppose that two kernels were registered for a layer:
1. Kernel `A` is registered for `Mode.DEFAULT`. This kernel supports training
(backward), but not `torch.compile`.
2. Kernel `B` is registered for `Mode.INFERENCE | Mode.COMPILE` and supports
`torch.compile`.
`kernelize` modes will then behave as follows:
- `Mode.INFERENCE | Mode.COMPILE`` uses kernel `B`: exact match.
- `Mode.INFERENCE` uses kernel `A`: no exact match, so fall back to
`Mode.DEFAULT`.
- `Mode.TRAIN` uses kernel `A`: no exact match, so fall back to
`Mode.DEFAULT`, which supports training.
- `Mode.TRAIN | Mode.COMPILE`: uses the original layer's
`forward`: no exact match, falling back to `Mode.DEFAULT` is not possible
because kernel `A` does not support `torch.compile`.
### Registering kernels for specific CUDA capabilities
Some kernels only work with newer CUDA architectures. For instance, some
kernels require capability 9.0 for the TMA unit on Hopper GPUs. `kernels`
supports registering layers for a range of CUDA capabilities. To do so,
you need to register the layer for a `Device` with type `cuda` and
set the supported range of CUDA capabilities with using `CUDAProperties`:
```python
kernel_layer_mapping = {
"SiluAndMul": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=89
),
): LayerRepository(
repo_id="kernels-community/activation",
layer_name="SiluAndMul",
),
Device(
type="cuda",
properties=CUDAProperties(
min_capability=90, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-community/activation-hopper",
layer_name="SiluAndMul",
),
}
}
```
Capabilities behave as follows:
- The minimum and maximum capabilities are inclusive.
- When a new kernel is registered with the same min/max capabilities as
an existing kernel, the new kernel will replace the old kernel.
- When there are multiple kernels that support a capability, the kernel
with the smaller capability interval will be used. E.g. given:
- `KernelA` with `min_capability=80` and `max_capability=89`;
- `KernelB` with `min_capability=75` and `max_capability=89`;
- `kernelize` runs on a system with capability 8.6.
Then `KernelA` will be used because the interval 80..89 is smaller
than 75..89. The motivation is that kernels with smaller ranges
tend to be more optimized for a specific set of GPUs. **This behavior
might still change in the future.**

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.7.0.dev0"
version = "0.7.0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },

View File

@ -1,5 +1,4 @@
from kernels.layer import (
CUDAProperties,
Device,
LayerRepository,
Mode,
@ -19,7 +18,6 @@ from kernels.utils import (
)
__all__ = [
"CUDAProperties",
"Device",
"LayerRepository",
"Mode",

View File

@ -1,200 +0,0 @@
# AVL-balanced interval trees. We could use the intervaltree
# packages, but it seems unmaintained and does not have type
# annotations.
from typing import Generic, List, Optional, Tuple, TypeVar
T = TypeVar("T")
class _Node(Generic[T]):
"""A node in the interval tree."""
def __init__(self, start: int, end: int, data: T):
self.start: int = start
self.end: int = end
self.data: T = data
self.max_end: int = end
self.left: Optional["_Node[T]"] = None
self.right: Optional["_Node[T]"] = None
self.height: int = 1
def __repr__(self) -> str:
return f"Node({self.start}, {self.end})"
class IntervalTree(Generic[T]):
"""A data structure to hold and query (unique) intervals."""
root: Optional[_Node[T]]
def __init__(self):
self.root = None
def insert(self, start: int, end: int, data: T) -> None:
"""
Inserts a new interval into the tree.
Args:
start: The starting point of the interval.
end: The ending point of the interval.
data: The data associated with this interval.
"""
self.root = self._insert(self.root, start, end, data)
def _get_height(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return node.height
def _get_balance(self, node: Optional[_Node[T]]) -> int:
if not node:
return 0
return self._get_height(node.left) - self._get_height(node.right)
def _update_node_attributes(self, node: _Node[T]) -> None:
node.height = 1 + max(self._get_height(node.left), self._get_height(node.right))
node.max_end = node.end
if node.left:
node.max_end = max(node.max_end, node.left.max_end)
if node.right:
node.max_end = max(node.max_end, node.right.max_end)
def _right_rotate(self, y: _Node[T]) -> _Node[T]:
"""Performs a right rotation."""
x = y.left
assert x is not None
T2 = x.right
x.right = y
y.left = T2
self._update_node_attributes(y)
self._update_node_attributes(x)
return x
def _left_rotate(self, x: _Node[T]) -> _Node[T]:
"""Performs a left rotation."""
y = x.right
assert y is not None
T2 = y.left
y.left = x
x.right = T2
self._update_node_attributes(x)
self._update_node_attributes(y)
return y
def _insert(
self, node: Optional[_Node[T]], start: int, end: int, data: T
) -> _Node[T]:
"""Recursive helper to insert a new node and balance the tree."""
if not node:
return _Node(start, end, data)
# Replace the data if the interval already exists.
if start == node.start and end == node.end:
node.data = data
return node
if start < node.start:
node.left = self._insert(node.left, start, end, data)
else:
node.right = self._insert(node.right, start, end, data)
self._update_node_attributes(node)
balance = self._get_balance(node)
# Left Left Case
if balance > 1 and node.left and start < node.left.start:
return self._right_rotate(node)
# Right Right Case
if balance < -1 and node.right and start >= node.right.start:
return self._left_rotate(node)
# Left Right Case
if balance > 1 and node.left and start >= node.left.start:
node.left = self._left_rotate(node.left)
return self._right_rotate(node)
# Right Left Case
if balance < -1 and node.right and start < node.right.start:
node.right = self._right_rotate(node.right)
return self._left_rotate(node)
return node
def search(self, point: int) -> List[T]:
"""
Searches for all intervals that contain the given point.
Args:
point: The point to search for.
Returns:
A list of data items from all matching intervals.
"""
results: List[T] = []
self._search(self.root, point, results)
return results
def _search(self, node: Optional[_Node[T]], point: int, results: List[T]) -> None:
"""Recursive helper to find all overlapping intervals."""
if node is None or point > node.max_end:
return
if node.left:
self._search(node.left, point, results)
if node.start <= point <= node.end:
results.append(node.data)
if point >= node.start and node.right:
self._search(node.right, point, results)
def find_smallest_interval(self, point: int) -> Optional[T]:
"""
Finds the item with the most specific (smallest) range for a given point.
Args:
point: The capability to look up.
Returns:
The data of the best-matching item, or None if no match is found.
"""
matches: List[Tuple[int, int, T]] = []
self._find_with_intervals(self.root, point, matches)
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# is just to ensure that we can compare against a trivial
# implementation in tests.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def _find_with_intervals(
self,
node: Optional[_Node[T]],
point: int,
results: List[Tuple[int, int, T]],
) -> None:
"""A modified search that collects interval ranges along with data."""
if node is None or point > node.max_end:
return
if node.left:
self._find_with_intervals(node.left, point, results)
if node.start <= point <= node.end:
results.append((node.start, node.end, node.data))
if point >= node.start and node.right:
self._find_with_intervals(node.right, point, results)

View File

@ -2,14 +2,11 @@ from __future__ import annotations
import inspect
import os
import sys
import warnings
from abc import ABC, abstractmethod
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Flag, auto
from functools import lru_cache
from types import MethodType
from typing import (
TYPE_CHECKING,
@ -20,7 +17,6 @@ from typing import (
Union,
)
from ._interval_tree import IntervalTree
from .utils import get_kernel
if TYPE_CHECKING:
@ -69,46 +65,14 @@ class Mode(Flag):
@dataclass(frozen=True)
class Device:
type: str
properties: Optional[CUDAProperties] = None
def __post_init__(self):
if self.properties is not None and isinstance(self.properties, CUDAProperties):
if self.type != "cuda":
raise ValueError("CUDAProperties is only supported for 'cuda' devices.")
def create_repo(self) -> _DeviceRepos:
"""Create an appropriate repository set for this device type."""
if self.type == "cuda":
return _CUDARepos()
elif self.type == "mps":
return _MPSRepos()
else:
raise ValueError(f"Unknown device type: {self.type}")
# In the future we might add compute capabilities, etc.
def __eq__(self, other):
if not isinstance(other, Device):
return NotImplemented
return self.type == other.type and self.properties == other.properties
return isinstance(other, Device) and self.type == other.type
def __hash__(self):
return hash((self.type, self.properties))
@dataclass(frozen=True)
class CUDAProperties:
min_capability: int
max_capability: int
def __eq__(self, other):
if not isinstance(other, CUDAProperties):
return NotImplemented
return (
self.min_capability == other.min_capability
and self.max_capability == other.max_capability
)
def __hash__(self):
return hash((self.min_capability, self.max_capability))
return hash(self.type)
@dataclass
@ -140,78 +104,8 @@ class LayerRepository:
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
class _DeviceRepos(ABC):
"""
Device-specific kernel layer repositories.
"""
@property
@abstractmethod
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]: ...
@abstractmethod
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
"""
Insert a repository for a specific device and mode.
"""
...
class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepository]
def __init__(self):
super().__init__()
self._repos = {}
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]:
return self._repos
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
if device.type != "mps":
raise ValueError(f"Device type must be 'mps', got {device.type}")
self._repos = repos
class _CUDARepos(_DeviceRepos):
_repos: IntervalTree[Dict[Mode, LayerRepository]]
def __init__(self):
super().__init__()
self.repos_by_capability = IntervalTree()
@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepository]]:
capability = _find_capability()
return self.repos_by_capability.find_smallest_interval(capability)
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
assert device.properties is None or isinstance(
device.properties, CUDAProperties
)
min_capability = (
0 if device.properties is None else device.properties.min_capability
)
max_capability = (
sys.maxsize
if device.properties is None
else device.properties.max_capability
)
self.repos_by_capability.insert(min_capability, max_capability, repos)
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
"_KERNEL_MAPPING", default={}
_KERNEL_MAPPING: ContextVar[Dict[str, Dict[Device, Dict[Mode, LayerRepository]]]] = (
ContextVar("_KERNEL_MAPPING", default={})
)
@ -287,8 +181,7 @@ def register_kernel_mapping(
else:
kernel_options = new_repo
feature_repos = device_repo.setdefault(device.type, device.create_repo())
feature_repos.insert(device, kernel_options)
device_repo[device] = kernel_options
def replace_kernel_forward_from_hub(
@ -325,7 +218,7 @@ def _select_repository(
def kernelize(
model: "nn.Module",
*,
mode: Mode = Mode.TRAINING | Mode.TORCH_COMPILE,
mode: Mode,
device: Optional[Union[str, "torch.device"]] = None,
use_fallback: bool = True,
):
@ -365,7 +258,6 @@ def kernelize(
device_type = Device(type=torch.device(device).type)
else:
device_type = Device(device.type)
assert isinstance(device_type, Device)
for _, module in model.named_modules():
@ -393,22 +285,12 @@ def kernelize(
continue
# Get kernel options for the device
property_repos = kernel.get(device_type.type)
if property_repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
_replace_forward(module, module_class)
continue
repos = property_repos.repos
repos = kernel.get(device_type)
if repos is None:
if not use_fallback:
raise ValueError(
f"No layer mapping for `{layer_name}` device `{device_type}` with the right properties"
f"No layer mapping for `{layer_name}` with device type `{device_type}`"
)
_replace_forward(module, module_class)
continue
@ -529,14 +411,6 @@ def _find_device(model: "nn.Module") -> Device:
return Device(type=param.device.type)
@lru_cache
def _find_capability() -> int:
import torch
major, minor = torch.cuda.get_device_capability(device=None)
return major * 10 + minor
def _conditionally_replace_forward(
*,
module: "nn.Module",

View File

@ -1,230 +0,0 @@
import random
from typing import Generic, List, Optional, Tuple, TypeVar
import pytest
from kernels._interval_tree import IntervalTree, _Node
T = TypeVar("T")
class SimpleIntervalStore(Generic[T]):
"""A simple O(n) implementation that stores intervals in a list."""
def __init__(self):
self.intervals: List[Tuple[int, int, T]] = []
def insert(self, start: int, end: int, data: T) -> None:
"""Insert an interval into the store."""
# Replace data if the interval already exists.
for i, (existing_start, existing_end, existing_data) in enumerate(
self.intervals
):
if existing_start == start and existing_end == end:
self.intervals[i] = (start, end, data)
return
self.intervals.append((start, end, data))
def find_smallest_interval(self, point: int) -> Optional[T]:
"""Find the best match using linear search."""
matches = []
for start, end, data in self.intervals:
if start <= point <= end:
matches.append((start, end, data))
if not matches:
return None
# Return the smallest interval, sort by memory location when
# there are multiple matches with the same interval size. This
# mirrors the ordering in the intervan tree.
best_match = min(matches, key=lambda x: (x[1] - x[0], id(x[2])))
return best_match[2]
def is_balanced(tree: IntervalTree[T]) -> bool:
"""Check if the AVL tree is properly balanced."""
def check_balance(node: Optional[_Node[T]]) -> Tuple[bool, int]:
if node is None:
return True, 0
# Left and right subtrees should be balanced.
left_balanced, left_height = check_balance(node.left)
if not left_balanced:
return False, -1
right_balanced, right_height = check_balance(node.right)
if not right_balanced:
return False, -1
# The difference in height should not exceed 1.
if abs(left_height - right_height) > 1:
return False, -1
# Check if the height is correct.
expected_height = 1 + max(left_height, right_height)
if node.height != expected_height:
return False, -1
return True, expected_height
balanced, _ = check_balance(tree.root)
return balanced
@pytest.fixture
def populated_tree() -> IntervalTree[str]:
"""Provides a pre-populated IntervalTree for testing."""
tree = IntervalTree[str]()
kernels = [
(80, 89, "Kernel_A_General_80_89"),
(86, 89, "Kernel_B_Ampere_86_89"),
(80, 86, "Kernel_C_Older_Ampere_80_86"),
(70, 75, "Kernel_D_Volta_70_75"),
(86, 87, "Kernel_E_Specific_86_87"),
]
for start, end, name in kernels:
tree.insert(start, end, name)
return tree
def test_find_smallest_interval_match_with_multiple_overlaps(populated_tree):
# Check that the smallest inteval is selected when there are
# multiple matching intervals.
assert populated_tree.find_smallest_interval(86) == "Kernel_E_Specific_86_87"
def test_find_single_match(populated_tree):
assert populated_tree.find_smallest_interval(72) == "Kernel_D_Volta_70_75"
assert populated_tree.find_smallest_interval(75) == "Kernel_D_Volta_70_75"
def test_no_match_outside_all_ranges(populated_tree):
# Check that no interval is found when the value is out of range
# (too small/too large).
assert populated_tree.find_smallest_interval(65) is None
assert populated_tree.find_smallest_interval(95) is None
def test_no_match_in_gap_between_ranges(populated_tree):
# Check that no interval is found when the value is between two
# intervals.
assert populated_tree.find_smallest_interval(78) is None
def test_boundary_conditions_start_and_end(populated_tree):
# Test exact upper/lower bounds of intervals.
assert populated_tree.find_smallest_interval(80) == "Kernel_C_Older_Ampere_80_86"
assert populated_tree.find_smallest_interval(89) == "Kernel_B_Ampere_86_89"
def test_empty_tree():
# Searching in an empty tree should return None.
empty_tree = IntervalTree[str]()
assert empty_tree.find_smallest_interval(100) is None
def test_multiple_equally_specific_matches():
# Check that we pick the match in a stable way when there is are
# multiple matching intervals with the same size.
tree = IntervalTree[str]()
str1 = "First_Narrow_Kernel"
str2 = "Second_Narrow_Kernel"
tree.insert(10, 20, "Wide_Kernel")
tree.insert(12, 17, str1)
tree.insert(14, 19, str2)
if id(str1) < id(str2):
assert tree.find_smallest_interval(15) == str1
else:
assert tree.find_smallest_interval(15) == str2
def test_property_based_interval_tree():
# Quick-check property-based testing:
#
# - Verify that the tree is balanced after each insertion.
# - Verify the query against a simple list-based implementation.
random.seed(42) # For reproducible tests
test_points = list(range(0, 101))
for _ in range(5):
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
intervals = []
for i in range(100):
start = random.randint(0, 90)
end = random.randint(start, 100)
data = f"interval_{i}_s{start}_e{end}"
intervals.append((start, end, data))
for i, (start, end, data) in enumerate(intervals):
tree.insert(start, end, data)
simple.insert(start, end, data)
# Check that tree is still balanced
assert is_balanced(
tree
), f"Tree became unbalanced after inserting interval {i}: ({start}, {end})"
for point in test_points:
tree_result = tree.find_smallest_interval(point)
simple_result = simple.find_smallest_interval(point)
assert tree_result == simple_result, (
f"Mismatch for point {point} after inserting {i+1} intervals. "
f"Tree: {tree_result}, Simple: {simple_result}. "
f"Last inserted: ({start}, {end})"
)
def test_property_based_edge_cases():
random.seed(123)
tree = IntervalTree[str]()
simple = SimpleIntervalStore[str]()
# Single-point intervals.
for i in range(10):
point = random.randint(0, 100)
data = f"single_point_{i}_{point}"
tree.insert(point, point, data)
simple.insert(point, point, data)
assert is_balanced(
tree
), f"Tree unbalanced after inserting single point {point}"
# Test the exact point and neighbors
for test_point in [point - 1, point, point + 1]:
if 0 <= test_point <= 100:
tree_result = tree.find_smallest_interval(test_point)
simple_result = simple.find_smallest_interval(test_point)
assert tree_result == simple_result
def test_unique_intervals_override():
"""Test that inserting an interval with the same start/end overrides the previous value."""
tree = IntervalTree[str]()
tree.insert(10, 20, "original_value")
assert tree.find_smallest_interval(15) == "original_value"
tree.insert(10, 20, "new_value")
assert tree.find_smallest_interval(15) == "new_value"
tree.insert(10, 25, "different_interval")
results = tree.search(15)
assert "new_value" in results
assert "different_interval" in results
assert len(results) == 2
tree.insert(10, 20, "final_value")
assert tree.find_smallest_interval(15) == "final_value"
assert is_balanced(tree)

View File

@ -1,4 +1,3 @@
import sys
from contextlib import nullcontext
import pytest
@ -14,12 +13,7 @@ from kernels import (
register_kernel_mapping,
use_kernel_forward_from_hub,
)
from kernels.layer import (
_KERNEL_MAPPING,
CUDAProperties,
_validate_layer,
use_kernel_mapping,
)
from kernels.layer import _KERNEL_MAPPING, _validate_layer, use_kernel_mapping
kernel_layer_mapping = {
"SiluAndMul": {
@ -124,55 +118,6 @@ def test_hub_forward(cls, device):
assert silu_and_mul_with_kernel.n_calls == 1
@pytest.mark.linux_only
def test_capability():
linear = TorchLinearWithCounter(32, 32).to("cuda")
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=75, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we called out to the kernel.
assert linear.n_calls == 0
with use_kernel_mapping(
{
"Linear": {
Device(
type="cuda",
properties=CUDAProperties(
min_capability=sys.maxsize, max_capability=sys.maxsize
),
): LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
}
}
):
kernelize(linear, mode=Mode.INFERENCE)
X = torch.randn(10, 32, device="cuda")
linear(X)
# Check that we didn't call out to the kernel because there is
# is no kernel with a matching capability..
assert linear.n_calls == 1
def test_layer_fallback_works():
@use_kernel_forward_from_hub("SiluAndMulNonExisting")
class SiluAndMulWithKernelFallback(SiluAndMul):
@ -237,7 +182,6 @@ def test_torch_compile_layer_with_fallback(cls, device):
torch.testing.assert_close(Y_compiled, Y)
@pytest.mark.linux_only
def test_mapping_contexts():
assert set(_KERNEL_MAPPING.get().keys()) == {
"SiluAndMul",
@ -281,7 +225,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -292,7 +238,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -301,7 +249,9 @@ def test_mapping_contexts():
"SiluAndMul",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/non-existing"
)
@ -312,7 +262,9 @@ def test_mapping_contexts():
"TestKernel",
}
assert (
_KERNEL_MAPPING.get()["SiluAndMul"]["cuda"].repos[Mode.DEFAULT].repo_id
_KERNEL_MAPPING.get()["SiluAndMul"][Device(type="cuda")][
Mode.DEFAULT
].repo_id
== "kernels-community/activation"
)
@ -351,7 +303,6 @@ def test_validate_kernel_layer():
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
@pytest.mark.linux_only
def test_invalid_mode_for_mapping_rejected():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -371,7 +322,6 @@ def test_invalid_mode_for_mapping_rejected():
kernelize(linear, mode=Mode.TRAINING)
@pytest.mark.linux_only
def test_kernel_modes():
linear = TorchLinearWithCounter(32, 32).to("cuda")
@ -400,11 +350,6 @@ def test_kernel_modes():
linear(X)
assert linear.n_calls == 0
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 0
# Case 2: register a kernel just for training. If no base kernel
# layer is registered, we fall back to the original layer.
with use_kernel_mapping(
@ -434,12 +379,6 @@ def test_kernel_modes():
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 2
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# No kernel for training + torch.compile, so fallback.
assert linear.n_calls == 3
# Case 3: register a kernel just for training and one for fallback.
with use_kernel_mapping(
{
@ -461,23 +400,17 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda")
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# Uses the training kernel.
assert linear.n_calls == 3
assert linear.n_calls == 2
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
# Uses the base kernel.
assert linear.n_calls == 3
assert linear.n_calls == 2
# Case 4: register a kernel with two preferences.
with use_kernel_mapping(
@ -497,22 +430,17 @@ def test_kernel_modes():
X = torch.randn(10, 32, device="cuda")
linear(X)
# No inference kernel, so fallback.
assert linear.n_calls == 4
assert linear.n_calls == 3
kernelize(linear, mode=Mode.TRAINING)
linear(X)
# No training kernel, so fallback.
assert linear.n_calls == 5
assert linear.n_calls == 4
kernelize(linear, mode=Mode.TRAINING | Mode.TORCH_COMPILE)
linear(X)
# We do have a training + torch.compile kernel.
assert linear.n_calls == 5
# Same as previous, since TRAINING | TORCH_COMPILE is the default.
kernelize(linear)
linear(X)
assert linear.n_calls == 5
assert linear.n_calls == 4
@pytest.mark.linux_only