Mirror of https://github.com/huggingface/kernels.git (synced 2025-10-20 12:33:46 +08:00)
Add support for project-wide locking of layers (#114)
This change adds `LockedLayerRepository` as an alternative to `LayerRepository`. `LockedLayerRepository` allows locking all kernel layers that are used at the project level. Example usage:

```
with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",
                layer_name="SomeLayer",
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```

This requires that the project has a `pyproject.toml` with kernel version specifications and a `kernels.lock` file with the locked kernels.
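For reference, the version specifications live under a `[tool.kernels.dependencies]` table in `pyproject.toml`. A minimal sketch, mirroring the test fixture added in this commit (the repository name and version range are taken from that fixture, not a recommendation):

```toml
[tool.kernels.dependencies]
# Accepted version range for the kernel repository; `kernels.lock` records the
# resolved commit SHA and variant hashes for whatever this range locks to.
"kernels-test/versions" = ">=0.1.0,<0.2.0"
```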
@@ -57,7 +57,7 @@ the Hub.

## 📚 Documentation

- [Using layers](docs/layers.md)
- [Locking kernel versions](docs/locking.md)
- [Locking kernel/layer versions](docs/locking.md)
- [Environment variables](docs/env.md)
- [Using kernels in a Docker container](docs/docker.md)
- [Kernel requirements](docs/kernel-requirements.md)
@@ -1,4 +1,4 @@
# Locking kernel versions
# Locking kernel/layer versions

Projects that use `setuptools` can lock the kernel versions that should be
used. First specify the accepted versions in `pyproject.toml` and make
@@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")

**Note:** the lock file is included in the package metadata, so it will only be visible
to `kernels` after doing an (editable or regular) installation of your project.

## Locked kernel layers

Locking is also supported for kernel layers. To use locked layers, register them
with the `LockedLayerRepository` class:

```python
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": LockedLayerRepository(
            repo_id="kernels-community/activation",
            layer_name="SiluAndMul",
        )
    }
}

register_kernel_mapping(kernel_layer_mapping)
```

## Pre-downloading locked kernels

Locked kernels can be pre-downloaded by running `kernels download .` in your
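Beyond the layer mapping shown above, the tests added in this commit also resolve a kernel against an explicit lock file. A minimal sketch of that call, with an illustrative path (by default, the lock embedded in the installed project's package metadata is used):

```python
from pathlib import Path

from kernels import load_kernel

# Load the kernel at the revision pinned in an explicit lock file instead of
# relying on the lock shipped in the installed package's metadata.
activation = load_kernel(
    "kernels-community/activation",
    lockfile=Path("kernels.lock"),  # illustrative path
)
```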
@@ -12,11 +12,13 @@ from copy import deepcopy
from dataclasses import dataclass
from enum import Flag, auto
from functools import lru_cache
from pathlib import Path
from types import MethodType
from typing import (
    TYPE_CHECKING,
    Dict,
    Optional,
    Protocol,
    Tuple,
    Type,
    Union,
@@ -24,7 +26,7 @@ from typing import (

from ._interval_tree import IntervalTree
from ._versions import select_revision_or_version
from .utils import get_kernel
from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel

if TYPE_CHECKING:
    import torch
@@ -114,6 +116,17 @@ class CUDAProperties:
        return hash((self.min_capability, self.max_capability))


class LayerRepositoryProtocol(Protocol):
    @property
    def layer_name(self) -> str: ...

    @property
    def repo_id(self) -> str: ...

    @property
    def revision(self) -> str: ...


class LayerRepository:
    """
    Repository and name of a layer.
@@ -173,7 +186,57 @@ class LayerRepository:
        return hash((self.layer_name, self.repo_id, self._revision, self._version))


_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
class LockedLayerRepository:
    """
    Repository and name of a layer.

    In contrast to `LayerRepository`, this class uses repositories that
    are locked inside a project.
    """

    def __init__(
        self,
        repo_id: str,
        *,
        lockfile: Optional[Path] = None,
        layer_name: str,
    ):
        """
        Construct a layer repository.

        Args:
            repo_id (`str`): The Hub repository containing the layer.
        """
        self.repo_id = repo_id
        self.lockfile = lockfile
        self.layer_name = layer_name

    @property
    @functools.lru_cache()
    def revision(self) -> str:
        if self.lockfile is None:
            locked_sha = _get_caller_locked_kernel(self.repo_id)
        else:
            with open(self.lockfile, "r") as f:
                locked_sha = _get_locked_kernel(self.repo_id, f.read())

        if locked_sha is None:
            raise ValueError(f"Kernel `{self.repo_id}` is not locked")

        return locked_sha

    def __eq__(self, other):
        return (
            isinstance(other, LockedLayerRepository)
            and self.layer_name == other.layer_name
            and self.repo_id == other.repo_id
        )

    def __hash__(self):
        return hash((self.layer_name, self.repo_id))


_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}


class _DeviceRepos(ABC):
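Because the mapping machinery now only relies on `LayerRepositoryProtocol` (a `layer_name`, a `repo_id`, and a `revision` property), `LayerRepository` and `LockedLayerRepository` are interchangeable wherever the protocol is accepted. A minimal, hypothetical sketch of a third repository type that would also satisfy the protocol (not part of this commit):

```python
class PinnedLayerRepository:
    """Hypothetical example: any object exposing these three properties
    conforms to LayerRepositoryProtocol via structural typing. A real
    implementation would also define __eq__/__hash__ so that the
    _CACHED_LAYER memoization behaves as expected."""

    def __init__(self, repo_id: str, *, layer_name: str, revision: str):
        self._repo_id = repo_id
        self._layer_name = layer_name
        self._revision = revision

    @property
    def layer_name(self) -> str:
        return self._layer_name

    @property
    def repo_id(self) -> str:
        return self._repo_id

    @property
    def revision(self) -> str:
        return self._revision
```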
@@ -185,10 +248,10 @@ class _DeviceRepos(ABC):
    @abstractmethod
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepository]]: ...
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...

    @abstractmethod
    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        """
        Insert a repository for a specific device and mode.
        """
@@ -196,7 +259,7 @@ class _DeviceRepos(ABC):


class _MPSRepos(_DeviceRepos):
    _repos: Dict[Mode, LayerRepository]
    _repos: Dict[Mode, LayerRepositoryProtocol]

    def __init__(self):
        super().__init__()
@@ -205,10 +268,10 @@ class _MPSRepos(_DeviceRepos):
    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepository]]:
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        return self._repos

    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        if device.type != "mps":
            raise ValueError(f"Device type must be 'mps', got {device.type}")

@@ -216,7 +279,7 @@ class _MPSRepos(_DeviceRepos):


class _CUDARepos(_DeviceRepos):
    _repos: IntervalTree[Dict[Mode, LayerRepository]]
    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]

    def __init__(self):
        super().__init__()
@@ -225,11 +288,11 @@ class _CUDARepos(_DeviceRepos):
    @property
    def repos(
        self,
    ) -> Optional[Dict[Mode, LayerRepository]]:
    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
        capability = _find_capability()
        return self.repos_by_capability.find_smallest_interval(capability)

    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
        assert device.properties is None or isinstance(
            device.properties, CUDAProperties
        )
@@ -254,7 +317,10 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
def use_kernel_mapping(
    mapping: Dict[
        str,
        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
        Dict[
            Union[Device, str],
            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
        ],
    ],
    *,
    inherit_mapping: bool = True,
@@ -285,7 +351,10 @@ def use_kernel_mapping(
def register_kernel_mapping(
    mapping: Dict[
        str,
        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
        Dict[
            Union[Device, str],
            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
        ],
    ],
):
    """
@@ -318,10 +387,10 @@ def register_kernel_mapping(
                Device(type=new_device) if isinstance(new_device, str) else new_device
            )

            if isinstance(new_repo, LayerRepository):
                kernel_options = {Mode.FALLBACK: new_repo}
            else:
            if isinstance(new_repo, dict):
                kernel_options = new_repo
            else:
                kernel_options = {Mode.FALLBACK: new_repo}

            feature_repos = device_repo.setdefault(device.type, device.create_repo())
            feature_repos.insert(device, kernel_options)
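The widened `isinstance` check above means a mapping entry may be either a single repository (registered under `Mode.FALLBACK`) or a dict keyed by `Mode`. A hedged sketch of the per-mode dict form, reusing the repository from the documentation example; the split between a locked inference repository and an unlocked fallback is illustrative, not taken from this commit:

```python
from kernels.layer import (
    LayerRepository,
    LockedLayerRepository,
    Mode,
    register_kernel_mapping,
)

# Dict form: explicit per-mode repositories. A bare repository would instead be
# registered under Mode.FALLBACK (the `else` branch in the code above).
kernel_layer_mapping = {
    "SiluAndMul": {
        "cuda": {
            # Locked revision, resolved from the project's kernels.lock.
            Mode.INFERENCE: LockedLayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
            # Unlocked repository used for all other modes.
            Mode.FALLBACK: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            ),
        }
    }
}

register_kernel_mapping(kernel_layer_mapping)
```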
@@ -373,10 +442,10 @@ _MODE_FALLBACK_PRIORITY = {


def _select_repository(
    repositories: Dict[Mode, LayerRepository],
    repositories: Dict[Mode, LayerRepositoryProtocol],
    *,
    mode: Mode,
) -> Optional[Tuple[LayerRepository, Mode]]:
) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
    # Get the fallback priority list for the requested mode
    if mode not in _MODE_FALLBACK_PRIORITY:
        raise ValueError(f"Unsupported mode: {mode}")
@@ -647,7 +716,7 @@ def _validate_layer_has_mode(
    *,
    layer_name: str,
    module: Type["nn.Module"],
    repo: LayerRepository,
    repo: LayerRepositoryProtocol,
    repo_mode: Mode,
):
    """
@@ -672,7 +741,7 @@ def _validate_layer_has_mode(


def _get_layer_memoize(
    repo: LayerRepository, module_class: Type["nn.Module"]
    repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
) -> Type["nn.Module"]:
    layer = _CACHED_LAYER.get(repo, None)
    if layer is not None:
tests/layer_locking/kernels.lock (new file, 12 lines)
@@ -0,0 +1,12 @@
[
  {
    "repo_id": "kernels-test/versions",
    "sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
    "variants": {
      "torch-universal": {
        "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
        "hash_type": "git_lfs_concat"
      }
    }
  }
]
tests/layer_locking/pyproject.toml (new file, 2 lines)
@@ -0,0 +1,2 @@
[tool.kernels.dependencies]
"kernels-test/versions" = ">=0.1.0,<0.2.0"
@@ -2,9 +2,17 @@ from dataclasses import dataclass
from pathlib import Path

import pytest
import torch.nn as nn

from kernels import load_kernel
from kernels.cli import download_kernels
from kernels.layer import (
    LockedLayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


# Mock download arguments class.
@@ -25,3 +33,28 @@ def test_load_locked():
    # Also validates that hashing works correctly.
    download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
    load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")


def test_layer_locked():
    project_dir = Path(__file__).parent / "layer_locking"

    @use_kernel_forward_from_hub("Version")
    class Version(nn.Module):
        def forward(self) -> str:
            return "0.0.0"

    version = Version()

    with use_kernel_mapping(
        {
            "Version": {
                "cuda": LockedLayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    lockfile=project_dir / "kernels.lock",
                )
            },
        }
    ):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.1.1"