Mirror of https://github.com/huggingface/kernels.git
Add support for project-wide locking of layers (#114)
This change adds `LockedLayerRepository` as an alternative to `LayerRepository`. `LockedLayerRepository` locks all kernel layers that are used at the project level. Example usage:

```python
with use_kernel_mapping(
    {
        "SomeLayer": {
            "cuda": LockedLayerRepository(
                repo_id="some-org/some-layer",
                layer_name="SomeLayer",
            )
        },
    }
):
    layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```

This requires that the project has a `pyproject.toml` with kernel version specifications and a `kernels.lock` file with the locked kernels.
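As context for the description above, a minimal sketch of the two ways the new `LockedLayerRepository` can resolve its pinned revision, based on the constructor added in this change: with no `lockfile` argument it falls back to the lock file of the calling project, while an explicit `lockfile` path (as in the new test) is read directly. The repository id, layer name, and path are illustrative values taken from the test fixtures.

```python
from pathlib import Path

from kernels.layer import LockedLayerRepository

# Resolve the revision from the calling project's installed lock file
# (requires a `pyproject.toml` with kernel specifications and a `kernels.lock`).
project_locked = LockedLayerRepository(
    repo_id="kernels-test/versions",
    layer_name="Version",
)

# Resolve the revision from an explicit lock file instead.
explicit_locked = LockedLayerRepository(
    repo_id="kernels-test/versions",
    layer_name="Version",
    lockfile=Path("tests/layer_locking/kernels.lock"),
)

# Accessing `.revision` performs the lookup; it raises ValueError if the
# repository is not present in the lock file.
```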
@@ -57,7 +57,7 @@ the Hub.
 ## 📚 Documentation

 - [Using layers](docs/layers.md)
-- [Locking kernel versions](docs/locking.md)
+- [Locking kernel/layer versions](docs/locking.md)
 - [Environment variables](docs/env.md)
 - [Using kernels in a Docker container](docs/docker.md)
 - [Kernel requirements](docs/kernel-requirements.md)
@@ -1,4 +1,4 @@
-# Locking kernel versions
+# Locking kernel/layer versions

 Projects that use `setuptools` can lock the kernel versions that should be
 used. First specify the accepted versions in `pyproject.toml` and make
@@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")
 **Note:** the lock file is included in the package metadata, so it will only be visible
 to `kernels` after doing an (editable or regular) installation of your project.

+## Locked kernel layers
+
+Locking is also supported for kernel layers. To use locked layers, register them
+with the `LockedLayerRepository` class:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": LockedLayerRepository(
+            repo_id="kernels-community/activation",
+            layer_name="SiluAndMul",
+        )
+    }
+}
+
+register_kernel_mapping(kernel_layer_mapping)
+```
+
 ## Pre-downloading locked kernels

 Locked kernels can be pre-downloaded by running `kernels download .` in your
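As a follow-up to the registration snippet added to the docs above, a minimal sketch of how such a locked mapping is applied with `kernelize`; it mirrors the commit description and the new test, and the stand-in `forward` implementation is only illustrative.

```python
import torch.nn as nn

from kernels.layer import Mode, kernelize, use_kernel_forward_from_hub

@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    # Plain PyTorch stand-in; the forward of the locked Hub layer replaces
    # this implementation when the module is kernelized.
    def forward(self, x):
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]

layer = SiluAndMul()

# Apply the registered (locked) kernel layer for inference on CUDA.
layer = kernelize(layer, device="cuda", mode=Mode.INFERENCE)
```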
@@ -12,11 +12,13 @@ from copy import deepcopy
 from dataclasses import dataclass
 from enum import Flag, auto
 from functools import lru_cache
+from pathlib import Path
 from types import MethodType
 from typing import (
     TYPE_CHECKING,
     Dict,
     Optional,
+    Protocol,
     Tuple,
     Type,
     Union,
@@ -24,7 +26,7 @@ from typing import (
 from ._interval_tree import IntervalTree
 from ._versions import select_revision_or_version
-from .utils import get_kernel
+from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel

 if TYPE_CHECKING:
     import torch
@@ -114,6 +116,17 @@ class CUDAProperties:
         return hash((self.min_capability, self.max_capability))


+class LayerRepositoryProtocol(Protocol):
+    @property
+    def layer_name(self) -> str: ...
+
+    @property
+    def repo_id(self) -> str: ...
+
+    @property
+    def revision(self) -> str: ...
+
+
 class LayerRepository:
     """
     Repository and name of a layer.
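The protocol introduced above is what lets `LayerRepository` and the new `LockedLayerRepository` be used interchangeably in the mapping types later in this diff: `typing.Protocol` checks structure rather than inheritance. A self-contained sketch of that idea with a stand-in class (not the library's):

```python
from typing import Protocol

class LayerRepositoryProtocol(Protocol):
    @property
    def layer_name(self) -> str: ...

    @property
    def repo_id(self) -> str: ...

    @property
    def revision(self) -> str: ...

class PinnedRepo:
    """Stand-in repository; it never inherits from the protocol."""

    def __init__(self, repo_id: str, layer_name: str, revision: str):
        self._repo_id = repo_id
        self._layer_name = layer_name
        self._revision = revision

    @property
    def repo_id(self) -> str:
        return self._repo_id

    @property
    def layer_name(self) -> str:
        return self._layer_name

    @property
    def revision(self) -> str:
        return self._revision

def describe(repo: LayerRepositoryProtocol) -> str:
    # Any object exposing the three properties satisfies the annotation.
    return f"{repo.repo_id}:{repo.layer_name}@{repo.revision}"

print(describe(PinnedRepo("kernels-test/versions", "Version", "main")))
```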
@@ -173,7 +186,57 @@ class LayerRepository:
         return hash((self.layer_name, self.repo_id, self._revision, self._version))


-_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
+class LockedLayerRepository:
+    """
+    Repository and name of a layer.
+
+    In contrast to `LayerRepository`, this class uses repositories that
+    are locked inside a project.
+    """
+
+    def __init__(
+        self,
+        repo_id: str,
+        *,
+        lockfile: Optional[Path] = None,
+        layer_name: str,
+    ):
+        """
+        Construct a layer repository.
+
+        Args:
+            repo_id (`str`): The Hub repository containing the layer.
+        """
+        self.repo_id = repo_id
+        self.lockfile = lockfile
+        self.layer_name = layer_name
+
+    @property
+    @functools.lru_cache()
+    def revision(self) -> str:
+        if self.lockfile is None:
+            locked_sha = _get_caller_locked_kernel(self.repo_id)
+        else:
+            with open(self.lockfile, "r") as f:
+                locked_sha = _get_locked_kernel(self.repo_id, f.read())
+
+        if locked_sha is None:
+            raise ValueError(f"Kernel `{self.repo_id}` is not locked")
+
+        return locked_sha
+
+    def __eq__(self, other):
+        return (
+            isinstance(other, LockedLayerRepository)
+            and self.layer_name == other.layer_name
+            and self.repo_id == other.repo_id
+        )
+
+    def __hash__(self):
+        return hash((self.layer_name, self.repo_id))
+
+
+_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}


 class _DeviceRepos(ABC):
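A note on the cache declaration above: because `LockedLayerRepository` hashes and compares on `(layer_name, repo_id)`, two instances describing the same locked layer resolve to the same `_CACHED_LAYER` entry. A small illustrative sketch of that memoization pattern with stand-in types (not the library's `_get_layer_memoize`):

```python
from typing import Dict

class RepoKey:
    """Stand-in keyed the same way as LockedLayerRepository."""

    def __init__(self, repo_id: str, layer_name: str):
        self.repo_id = repo_id
        self.layer_name = layer_name

    def __eq__(self, other):
        return (
            isinstance(other, RepoKey)
            and self.layer_name == other.layer_name
            and self.repo_id == other.repo_id
        )

    def __hash__(self):
        return hash((self.layer_name, self.repo_id))

_CACHE: Dict[RepoKey, str] = {}
_LOADS = 0

def get_layer_memoized(repo: RepoKey) -> str:
    """Load a layer once per (repo_id, layer_name) pair."""
    global _LOADS
    cached = _CACHE.get(repo)
    if cached is not None:
        return cached
    _LOADS += 1  # the expensive Hub download would happen here
    layer = f"<layer {repo.layer_name} from {repo.repo_id}>"
    _CACHE[repo] = layer
    return layer

a = RepoKey("kernels-test/versions", "Version")
b = RepoKey("kernels-test/versions", "Version")
assert get_layer_memoized(a) == get_layer_memoized(b)
assert _LOADS == 1  # the second, equal key reused the cached entry
```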
@@ -185,10 +248,10 @@ class _DeviceRepos(ABC):
     @abstractmethod
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]: ...
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...

     @abstractmethod
-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         """
         Insert a repository for a specific device and mode.
         """
@@ -196,7 +259,7 @@ class _DeviceRepos(ABC):


 class _MPSRepos(_DeviceRepos):
-    _repos: Dict[Mode, LayerRepository]
+    _repos: Dict[Mode, LayerRepositoryProtocol]

     def __init__(self):
         super().__init__()
@@ -205,10 +268,10 @@ class _MPSRepos(_DeviceRepos):
     @property
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]:
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
         return self._repos

-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         if device.type != "mps":
             raise ValueError(f"Device type must be 'mps', got {device.type}")
@@ -216,7 +279,7 @@ class _MPSRepos(_DeviceRepos):


 class _CUDARepos(_DeviceRepos):
-    _repos: IntervalTree[Dict[Mode, LayerRepository]]
+    _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]

     def __init__(self):
         super().__init__()
@@ -225,11 +288,11 @@ class _CUDARepos(_DeviceRepos):
     @property
     def repos(
         self,
-    ) -> Optional[Dict[Mode, LayerRepository]]:
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
         capability = _find_capability()
         return self.repos_by_capability.find_smallest_interval(capability)

-    def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         assert device.properties is None or isinstance(
             device.properties, CUDAProperties
         )
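For context on the `_CUDARepos` lookup retained above: repositories are registered per CUDA compute-capability interval, and `find_smallest_interval` picks the most specific interval containing the current capability. A stand-in sketch of that selection rule with plain dictionaries (not the library's `IntervalTree`); the capability encoding is an assumption:

```python
from typing import Dict, Optional, Tuple

Interval = Tuple[int, int]  # (min_capability, max_capability)

def find_smallest_interval(
    intervals: Dict[Interval, str], capability: int
) -> Optional[str]:
    """Return the value of the narrowest interval containing `capability`."""
    containing = [(lo, hi) for (lo, hi) in intervals if lo <= capability <= hi]
    if not containing:
        return None
    lo, hi = min(containing, key=lambda iv: iv[1] - iv[0])
    return intervals[(lo, hi)]

repos_by_capability = {
    (0, 999999): "generic kernels",       # matches any capability
    (75, 89): "sm75-sm89 tuned kernels",  # narrower, wins when it applies
}

print(find_smallest_interval(repos_by_capability, 80))   # sm75-sm89 tuned kernels
print(find_smallest_interval(repos_by_capability, 100))  # generic kernels
```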
@@ -254,7 +317,10 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
 def use_kernel_mapping(
     mapping: Dict[
         str,
-        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+        Dict[
+            Union[Device, str],
+            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
+        ],
     ],
     *,
     inherit_mapping: bool = True,
@@ -285,7 +351,10 @@ def use_kernel_mapping(
 def register_kernel_mapping(
     mapping: Dict[
         str,
-        Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
+        Dict[
+            Union[Device, str],
+            Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
+        ],
     ],
 ):
     """
@@ -318,10 +387,10 @@ def register_kernel_mapping(
                 Device(type=new_device) if isinstance(new_device, str) else new_device
             )

-            if isinstance(new_repo, LayerRepository):
-                kernel_options = {Mode.FALLBACK: new_repo}
-            else:
+            if isinstance(new_repo, dict):
                 kernel_options = new_repo
+            else:
+                kernel_options = {Mode.FALLBACK: new_repo}

             feature_repos = device_repo.setdefault(device.type, device.create_repo())
             feature_repos.insert(device, kernel_options)
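The branch rewritten above now tests for a plain `dict` instead of `LayerRepository`, so any repository object satisfying the protocol, including `LockedLayerRepository`, is wrapped as a fallback entry. A sketch of the two mapping shapes this accepts; the repository ids are illustrative and only modes already used in this change appear:

```python
from kernels.layer import (
    LayerRepository,
    LockedLayerRepository,
    Mode,
    register_kernel_mapping,
)

# Shape 1: a bare repository object. It is stored as {Mode.FALLBACK: repo},
# i.e. it applies to every mode.
register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": LockedLayerRepository(
                repo_id="kernels-community/activation",
                layer_name="SiluAndMul",
            )
        }
    }
)

# Shape 2: an explicit per-mode dict, stored unchanged.
register_kernel_mapping(
    {
        "SiluAndMul": {
            "cuda": {
                Mode.INFERENCE: LayerRepository(
                    repo_id="kernels-community/activation",
                    layer_name="SiluAndMul",
                ),
            }
        }
    }
)
```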
@@ -373,10 +442,10 @@ _MODE_FALLBACK_PRIORITY = {


 def _select_repository(
-    repositories: Dict[Mode, LayerRepository],
+    repositories: Dict[Mode, LayerRepositoryProtocol],
     *,
     mode: Mode,
-) -> Optional[Tuple[LayerRepository, Mode]]:
+) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
     # Get the fallback priority list for the requested mode
     if mode not in _MODE_FALLBACK_PRIORITY:
         raise ValueError(f"Unsupported mode: {mode}")
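`_select_repository` walks a per-mode priority list and returns the first registered repository together with the mode it was registered under. The actual contents of `_MODE_FALLBACK_PRIORITY` are not shown in this diff; the sketch below assumes a simple "INFERENCE, then FALLBACK" ordering purely for illustration:

```python
from enum import Flag, auto
from typing import Dict, List, Optional, Tuple

class Mode(Flag):
    FALLBACK = auto()
    INFERENCE = auto()

# Assumed priority table, for illustration only.
_MODE_FALLBACK_PRIORITY: Dict[Mode, List[Mode]] = {
    Mode.INFERENCE: [Mode.INFERENCE, Mode.FALLBACK],
    Mode.FALLBACK: [Mode.FALLBACK],
}

def select_repository(
    repositories: Dict[Mode, str], *, mode: Mode
) -> Optional[Tuple[str, Mode]]:
    if mode not in _MODE_FALLBACK_PRIORITY:
        raise ValueError(f"Unsupported mode: {mode}")
    for candidate in _MODE_FALLBACK_PRIORITY[mode]:
        repo = repositories.get(candidate)
        if repo is not None:
            return repo, candidate
    return None

# A mapping that only registered a fallback repository still serves inference.
print(select_repository({Mode.FALLBACK: "activation repo"}, mode=Mode.INFERENCE))
```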
@@ -647,7 +716,7 @@ def _validate_layer_has_mode(
     *,
     layer_name: str,
     module: Type["nn.Module"],
-    repo: LayerRepository,
+    repo: LayerRepositoryProtocol,
     repo_mode: Mode,
 ):
     """
@@ -672,7 +741,7 @@ def _validate_layer_has_mode(


 def _get_layer_memoize(
-    repo: LayerRepository, module_class: Type["nn.Module"]
+    repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
 ) -> Type["nn.Module"]:
     layer = _CACHED_LAYER.get(repo, None)
     if layer is not None:
tests/layer_locking/kernels.lock (new file, 12 lines)
@@ -0,0 +1,12 @@
+[
+  {
+    "repo_id": "kernels-test/versions",
+    "sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
+    "variants": {
+      "torch-universal": {
+        "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
+        "hash_type": "git_lfs_concat"
+      }
+    }
+  }
+]
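The lock file added above is a JSON list of entries with a `repo_id`, a pinned `sha`, and per-variant content hashes. A small sketch of looking up the pinned revision for a repository, assuming only the format shown in this fixture (illustrative, not the library's `_get_locked_kernel`):

```python
import json
from pathlib import Path
from typing import Optional

def locked_sha(lockfile: Path, repo_id: str) -> Optional[str]:
    """Return the pinned commit SHA for `repo_id`, or None if it is not locked."""
    for entry in json.loads(lockfile.read_text()):
        if entry["repo_id"] == repo_id:
            return entry["sha"]
    return None

sha = locked_sha(Path("tests/layer_locking/kernels.lock"), "kernels-test/versions")
print(sha)  # dc142fd6c9920c993d32be6358b78957c58681c3
```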
tests/layer_locking/pyproject.toml (new file, 2 lines)
@@ -0,0 +1,2 @@
+[tool.kernels.dependencies]
+"kernels-test/versions" = ">=0.1.0,<0.2.0"
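The `pyproject.toml` fixture above constrains `kernels-test/versions` to the range `>=0.1.0,<0.2.0`. How such a specifier narrows down candidate versions can be illustrated with the `packaging` library; this shows the constraint semantics only, not the resolution code in `kernels`, and the available version list is hypothetical:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=0.1.0,<0.2.0")

# Hypothetical version tags published for the kernel repository.
available = [Version(v) for v in ["0.0.9", "0.1.0", "0.1.1", "0.2.0"]]

matching = sorted(v for v in available if v in spec)
print(matching)       # [<Version('0.1.0')>, <Version('0.1.1')>]
print(max(matching))  # 0.1.1, consistent with the test asserting "0.1.1"
```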
@@ -2,9 +2,17 @@ from dataclasses import dataclass
 from pathlib import Path

 import pytest
+import torch.nn as nn

 from kernels import load_kernel
 from kernels.cli import download_kernels
+from kernels.layer import (
+    LockedLayerRepository,
+    Mode,
+    kernelize,
+    use_kernel_forward_from_hub,
+    use_kernel_mapping,
+)


 # Mock download arguments class.
@@ -25,3 +33,28 @@ def test_load_locked():
     # Also validates that hashing works correctly.
     download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
+
+
+def test_layer_locked():
+    project_dir = Path(__file__).parent / "layer_locking"
+
+    @use_kernel_forward_from_hub("Version")
+    class Version(nn.Module):
+        def forward(self) -> str:
+            return "0.0.0"
+
+    version = Version()
+
+    with use_kernel_mapping(
+        {
+            "Version": {
+                "cuda": LockedLayerRepository(
+                    repo_id="kernels-test/versions",
+                    layer_name="Version",
+                    lockfile=project_dir / "kernels.lock",
+                )
+            },
+        }
+    ):
+        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+
+    assert version() == "0.1.1"