Add support for project-wide locking of layers (#114)

This change adds `LockedLayerRepository` as an alternative to
`LayerRepository`. `LockedLayerRepository` resolves kernel layers through
the project's lock file, so every layer that a project uses is pinned to
a locked revision. Example usage:

```
with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",
                layer_name="SomeLayer",
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```

This requires that the project has a `pyproject.toml` with kernel
version specifications and a `kernels.lock` file with the locked kernels.
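
When the lock cannot be resolved from the calling project's installed
metadata, the new class also accepts an explicit `lockfile` path (see the
`layer.py` and test diffs below). A minimal sketch; the repository id,
layer name, and lock file path are placeholders:

```python
from pathlib import Path

import torch.nn as nn

from kernels.layer import (
    LockedLayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("SomeLayer")
class SomeLayer(nn.Module):
    def forward(self, x):
        return x


layer = SomeLayer()

with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",  # placeholder repository
                layer_name="SomeLayer",
                # Explicit lock file; omit to resolve the lock from the
                # calling project's metadata instead.
                lockfile=Path("kernels.lock"),
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```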
Author: Daniël de Kok
Date: 2025-07-23 09:37:05 +02:00
Committed by: GitHub
Parent: 4a04c005e3
Commit: 81088d44e8

6 changed files with 155 additions and 21 deletions


```diff
@@ -57,7 +57,7 @@ the Hub.
 ## 📚 Documentation
 
 - [Using layers](docs/layers.md)
-- [Locking kernel versions](docs/locking.md)
+- [Locking kernel/layer versions](docs/locking.md)
 - [Environment variables](docs/env.md)
 - [Using kernels in a Docker container](docs/docker.md)
 - [Kernel requirements](docs/kernel-requirements.md)
```


```diff
@@ -1,4 +1,4 @@
-# Locking kernel versions
+# Locking kernel/layer versions
 
 Projects that use `setuptools` can lock the kernel versions that should be
 used. First specify the accepted versions in `pyproject.toml` and make
@@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")
 **Note:** the lock file is included in the package metadata, so it will only be visible
 to `kernels` after doing an (editable or regular) installation of your project.
 
+## Locked kernel layers
+
+Locking is also supported for kernel layers. To use locked layers, register them
+with the `LockedLayerRepository` class:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": LockedLayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+        )
+    }
+}
+
+register_kernel_mapping(kernel_layer_mapping)
+```
+
 ## Pre-downloading locked kernels
 
 Locked kernels can be pre-downloaded by running `kernels download .` in your
```
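
The mapping value can also be keyed by `Mode` (see the type changes in the
`layer.py` diff below), so locked repositories could be registered per
kernelization mode. A sketch under that assumption, reusing the repository
from the documentation example:

```python
from kernels.layer import LockedLayerRepository, Mode, register_kernel_mapping

# One locked repository, registered for both inference and fallback; a project
# could point different modes at different locked repositories instead.
silu_and_mul = LockedLayerRepository(
    repo_id="kernels-community/activation",
    layer_name="SiluAndMul",
)

register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": {
                Mode.INFERENCE: silu_and_mul,
                Mode.FALLBACK: silu_and_mul,
            }
        }
    }
)
```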


```diff
@@ -12,11 +12,13 @@ from copy import deepcopy
 from dataclasses import dataclass
 from enum import Flag, auto
 from functools import lru_cache
+from pathlib import Path
 from types import MethodType
 from typing import (
     TYPE_CHECKING,
     Dict,
     Optional,
+    Protocol,
     Tuple,
     Type,
     Union,
@@ -24,7 +26,7 @@ from typing import (
 
 from ._interval_tree import IntervalTree
 from ._versions import select_revision_or_version
-from .utils import get_kernel
+from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel
 
 if TYPE_CHECKING:
     import torch
@@ -114,6 +116,17 @@ class CUDAProperties:
         return hash((self.min_capability, self.max_capability))
 
 
+class LayerRepositoryProtocol(Protocol):
+    @property
+    def layer_name(self) -> str: ...
+
+    @property
+    def repo_id(self) -> str: ...
+
+    @property
+    def revision(self) -> str: ...
+
+
 class LayerRepository:
     """
     Repository and name of a layer.
@@ -173,7 +186,57 @@ class LayerRepository:
         return hash((self.layer_name, self.repo_id, self._revision, self._version))
 
 
-_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
+class LockedLayerRepository:
+    """
+    Repository and name of a layer.
+
+    In contrast to `LayerRepository`, this class uses repositories that
+    are locked inside a project.
+    """
+
+    def __init__(
+        self,
+        repo_id: str,
+        *,
+        lockfile: Optional[Path] = None,
+        layer_name: str,
+    ):
+        """
+        Construct a layer repository.
+
+        Args:
+            repo_id (`str`): The Hub repository containing the layer.
+        """
+        self.repo_id = repo_id
+        self.lockfile = lockfile
+        self.layer_name = layer_name
+
+    @property
+    @functools.lru_cache()
+    def revision(self) -> str:
+        if self.lockfile is None:
+            locked_sha = _get_caller_locked_kernel(self.repo_id)
+        else:
+            with open(self.lockfile, "r") as f:
+                locked_sha = _get_locked_kernel(self.repo_id, f.read())
+
+        if locked_sha is None:
+            raise ValueError(f"Kernel `{self.repo_id}` is not locked")
+
+        return locked_sha
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, LockedLayerRepository)
+            and self.layer_name == other.layer_name
+            and self.repo_id == other.repo_id
+        )
+
+    def __hash__(self):
+        return hash((self.layer_name, self.repo_id))
+
+
+_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
 
 
 class _DeviceRepos(ABC):
@@ -185,10 +248,10 @@ class _DeviceRepos(ABC):
     @abstractmethod
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]: ...
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...
 
     @abstractmethod
-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         """
         Insert a repository for a specific device and mode.
         """
@@ -196,7 +259,7 @@ class _DeviceRepos(ABC):
 
 class _MPSRepos(_DeviceRepos):
-    _repos: Dict[Mode, LayerRepository]
+    _repos: Dict[Mode, LayerRepositoryProtocol]
 
     def __init__(self):
         super().__init__()
 
@@ -205,10 +268,10 @@ class _MPSRepos(_DeviceRepos):
     @property
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]:
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
         return self._repos
 
-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         if device.type != "mps":
             raise ValueError(f"Device type must be 'mps', got {device.type}")
 
@@ -216,7 +279,7 @@ class _MPSRepos(_DeviceRepos):
 
 class _CUDARepos(_DeviceRepos):
-    _repos: IntervalTree[Dict[Mode, LayerRepository]]
+    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]
 
     def __init__(self):
         super().__init__()
 
@@ -225,11 +288,11 @@ class _CUDARepos(_DeviceRepos):
     @property
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]:
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
         capability = _find_capability()
         return self.repos_by_capability.find_smallest_interval(capability)
 
-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         assert device.properties is None or isinstance(
             device.properties, CUDAProperties
         )
@@ -254,7 +317,10 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
 def use_kernel_mapping(
     mapping: Dict[
         str,
-        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+        Dict[
+            Union[Device, str],
+            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
+        ],
     ],
     *,
     inherit_mapping: bool = True,
@@ -285,7 +351,10 @@ def use_kernel_mapping(
 def register_kernel_mapping(
     mapping: Dict[
         str,
-        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+        Dict[
+            Union[Device, str],
+            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
+        ],
     ],
 ):
     """
@@ -318,10 +387,10 @@ def register_kernel_mapping(
             Device(type=new_device) if isinstance(new_device, str) else new_device
         )
 
-        if isinstance(new_repo, LayerRepository):
-            kernel_options = {Mode.FALLBACK: new_repo}
-        else:
+        if isinstance(new_repo, dict):
             kernel_options = new_repo
+        else:
+            kernel_options = {Mode.FALLBACK: new_repo}
 
         feature_repos = device_repo.setdefault(device.type, device.create_repo())
         feature_repos.insert(device, kernel_options)
@@ -373,10 +442,10 @@ _MODE_FALLBACK_PRIORITY = {
 
 
 def _select_repository(
-    repositories: Dict[Mode, LayerRepository],
+    repositories: Dict[Mode, LayerRepositoryProtocol],
     *,
     mode: Mode,
-) -> Optional[Tuple[LayerRepository, Mode]]:
+) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
     # Get the fallback priority list for the requested mode
     if mode not in _MODE_FALLBACK_PRIORITY:
         raise ValueError(f"Unsupported mode: {mode}")
@@ -647,7 +716,7 @@ def _validate_layer_has_mode(
     *,
     layer_name: str,
     module: Type["nn.Module"],
-    repo: LayerRepository,
+    repo: LayerRepositoryProtocol,
     repo_mode: Mode,
 ):
     """
@@ -672,7 +741,7 @@ def _validate_layer_has_mode(
 
 
 def _get_layer_memoize(
-    repo: LayerRepository, module_class: Type["nn.Module"]
+    repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
 ) -> Type["nn.Module"]:
     layer = _CACHED_LAYER.get(repo, None)
     if layer is not None:
```
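
Because the mapping and caching types now accept anything that satisfies
`LayerRepositoryProtocol`, other repository sources can plug into the same
machinery. A hypothetical sketch (not part of this commit) of a repository
pinned to a hard-coded revision:

```python
from kernels.layer import register_kernel_mapping


class PinnedLayerRepository:
    """Hypothetical repository type that satisfies LayerRepositoryProtocol."""

    def __init__(self, repo_id: str, *, layer_name: str, revision: str):
        self._repo_id = repo_id
        self._layer_name = layer_name
        self._revision = revision

    @property
    def layer_name(self) -> str:
        return self._layer_name

    @property
    def repo_id(self) -> str:
        return self._repo_id

    @property
    def revision(self) -> str:
        return self._revision

    # Repositories are used as dictionary keys (see _CACHED_LAYER), so
    # equality and hashing should be kept consistent.
    def __eq__(self, other):
        return (
            isinstance(other, PinnedLayerRepository)
            and self.repo_id == other.repo_id
            and self.layer_name == other.layer_name
            and self.revision == other.revision
        )

    def __hash__(self):
        return hash((self.repo_id, self.layer_name, self._revision))


register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": PinnedLayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
                revision="main",  # illustrative revision
            )
        }
    }
)
```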


```diff
@@ -0,0 +1,12 @@
+[
+  {
+    "repo_id": "kernels-test/versions",
+    "sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
+    "variants": {
+      "torch-universal": {
+        "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
+        "hash_type": "git_lfs_concat"
+      }
+    }
+  }
+]
```


```diff
@@ -0,0 +1,2 @@
+[tool.kernels.dependencies]
+"kernels-test/versions" = ">=0.1.0,<0.2.0"
```


```diff
@@ -2,9 +2,17 @@ from dataclasses import dataclass
 from pathlib import Path
 
 import pytest
+import torch.nn as nn
 
 from kernels import load_kernel
 from kernels.cli import download_kernels
+from kernels.layer import (
+    LockedLayerRepository,
+    Mode,
+    kernelize,
+    use_kernel_forward_from_hub,
+    use_kernel_mapping,
+)
 
 
 # Mock download arguments class.
@@ -25,3 +33,28 @@ def test_load_locked():
     # Also validates that hashing works correctly.
     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
+
+
+def test_layer_locked():
+    project_dir = Path(__file__).parent / "layer_locking"
+
+    @use_kernel_forward_from_hub("Version")
+    class Version(nn.Module):
+        def forward(self) -> str:
+            return "0.0.0"
+
+    version = Version()
+
+    with use_kernel_mapping(
+        {
+            "Version": {
+                "cuda": LockedLayerRepository(
+                    repo_id="kernels-test/versions",
+                    layer_name="Version",
+                    lockfile=project_dir / "kernels.lock",
+                )
+            },
+        }
+    ):
+        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+    assert version() == "0.1.1"
```
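
As a quick sanity check against the fixture above (not part of the commit;
the path assumes the tests live in a `tests/` directory), the locked
revision should resolve to the SHA pinned in `kernels.lock`:

```python
from pathlib import Path

from kernels.layer import LockedLayerRepository

repo = LockedLayerRepository(
    repo_id="kernels-test/versions",
    layer_name="Version",
    lockfile=Path("tests/layer_locking/kernels.lock"),  # assumed location
)

# Expected to match the "sha" entry in the fixture lock file.
assert repo.revision == "dc142fd6c9920c993d32be6358b78957c58681c3"
```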