mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 05:30:30 +08:00
Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
0429131630 | |||
967ac581b8 | |||
81088d44e8 | |||
4a04c005e3 | |||
6d3c6daf20 | |||
071900fd69 |
@ -57,7 +57,7 @@ the Hub.
|
||||
## 📚 Documentation
|
||||
|
||||
- [Using layers](docs/layers.md)
|
||||
- [Locking kernel versions](docs/locking.md)
|
||||
- [Locking kernel/layer versions](docs/locking.md)
|
||||
- [Environment variables](docs/env.md)
|
||||
- [Using kernels in a Docker container](docs/docker.md)
|
||||
- [Kernel requirements](docs/kernel-requirements.md)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Locking kernel versions
|
||||
# Locking kernel/layer versions
|
||||
|
||||
Projects that use `setuptools` can lock the kernel versions that should be
|
||||
used. First specify the accepted versions in `pyproject.toml` and make
|
||||
@ -26,6 +26,24 @@ activation = get_locked_kernel("kernels-community/activation")
|
||||
**Note:** the lock file is included in the package metadata, so it will only be visible
|
||||
to `kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Locked kernel layers
|
||||
|
||||
Locking is also supported for kernel layers. To use locked layers, register them
|
||||
with the `LockedLayerRepository` class:
|
||||
|
||||
```python
|
||||
kernel_layer_mapping = {
|
||||
"SiluAndMul": {
|
||||
"cuda": LockedLayerRepository(
|
||||
repo_id="kernels-community/activation",
|
||||
layer_name="SiluAndMul",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
register_kernel_mapping(kernel_layer_mapping)
|
||||
```
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `kernels download .` in your
|
||||
|
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "kernels"
|
||||
version = "0.8.0.dev0"
|
||||
version = "0.8.1"
|
||||
description = "Download compute kernels"
|
||||
authors = [
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
|
52
src/kernels/_versions.py
Normal file
52
src/kernels/_versions.py
Normal file
@ -0,0 +1,52 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
    """Get kernel versions that are available in the repository."""
    available: Dict[Version, GitRefInfo] = {}
    for tag in HfApi().list_repo_refs(repo_id).tags:
        # Version tags look like `v<version>`; ignore everything else.
        if not tag.name.startswith("v"):
            continue
        try:
            parsed = Version(tag.name[1:])
        except InvalidVersion:
            # Tag name after `v` is not a valid Python version; skip it.
            continue
        available[parsed] = tag

    return available
|
||||
|
||||
|
||||
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
    """
    Resolve a version specifier to the Git ref of the newest matching version.

    The version specifier can be any valid Python version specifier:
    https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers

    Raises:
        ValueError: If no version tag of `repo_id` satisfies `version_spec`.
    """
    versions = _get_available_versions(repo_id)
    spec = SpecifierSet(version_spec)
    # Iterating the dict yields its keys (the available `Version`s).
    matching = sorted(spec.filter(versions))

    if not matching:
        raise ValueError(
            f"No version of `{repo_id}` satisfies requirement: {version_spec}"
        )

    # The newest accepted version wins.
    return versions[matching[-1]]
|
||||
|
||||
|
||||
def select_revision_or_version(
    repo_id: str, revision: Optional[str], version: Optional[str]
) -> str:
    """Resolve a `revision`/`version` argument pair to a concrete revision.

    At most one of `revision` or `version` may be given; with neither,
    `"main"` is used. A `version` specifier is resolved to the commit of the
    newest matching version tag of `repo_id`.

    Raises:
        ValueError: If both `revision` and `version` are specified.
    """
    if revision is not None and version is not None:
        raise ValueError("Either a revision or a version must be specified, not both.")

    if version is not None:
        # Resolve the specifier against the repository's version tags.
        resolved = resolve_version_spec_as_ref(repo_id, version).target_commit
    elif revision is not None:
        resolved = revision
    else:
        # Neither given: fall back to the default branch.
        resolved = "main"

    assert resolved is not None
    return resolved
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@ -8,21 +9,24 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Flag, auto
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from ._interval_tree import IntervalTree
|
||||
from .utils import get_kernel
|
||||
from ._versions import select_revision_or_version
|
||||
from .utils import _get_caller_locked_kernel, _get_locked_kernel, get_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
@ -112,33 +116,127 @@ class CUDAProperties:
|
||||
return hash((self.min_capability, self.max_capability))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerRepositoryProtocol(Protocol):
    """Structural interface for layer-repository implementations.

    Any object exposing these read-only attributes can be used wherever a
    layer repository is expected.
    """

    @property
    def layer_name(self) -> str: ...

    @property
    def repo_id(self) -> str: ...

    @property
    def revision(self) -> str: ...
|
||||
|
||||
|
||||
class LayerRepository:
|
||||
"""
|
||||
Repository and name of a layer.
|
||||
"""
|
||||
|
||||
layer_name: str = field(
|
||||
metadata={"help": "The name of the layer in the kernel repository."}
|
||||
)
|
||||
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
|
||||
revision: str = field(
|
||||
default="main", metadata={"help": "The revision of the layer."}
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
layer_name: str,
|
||||
revision: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Construct a layer repository.
|
||||
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the layer.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
Cannot be used together with `version`.
|
||||
version (`str`, *optional*): The kernel version to download. This
|
||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
"""
|
||||
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError(
|
||||
"Either a revision or a version must be specified, not both."
|
||||
)
|
||||
|
||||
self.repo_id = repo_id
|
||||
self.layer_name = layer_name
|
||||
|
||||
# We are going to resolve these lazily, since we do not want
|
||||
# to do a network request for every registered LayerRepository.
|
||||
self._revision = revision
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def revision(self) -> str:
|
||||
return select_revision_or_version(
|
||||
repo_id=self.repo_id, revision=self._revision, version=self._version
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, LayerRepository)
|
||||
and self.layer_name == other.layer_name
|
||||
and self.repo_id == other.repo_id
|
||||
and self.revision == other.revision
|
||||
and self._revision == other._revision
|
||||
and self._version == other._version
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.layer_name, self.repo_id, self.revision))
|
||||
return hash((self.layer_name, self.repo_id, self._revision, self._version))
|
||||
|
||||
|
||||
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
|
||||
class LockedLayerRepository:
    """
    Repository and name of a layer.

    In contrast to `LayerRepository`, this class uses repositories that
    are locked inside a project.
    """

    def __init__(
        self,
        repo_id: str,
        *,
        lockfile: Optional[Path] = None,
        layer_name: str,
    ):
        """
        Construct a layer repository.

        Args:
            repo_id (`str`): The Hub repository containing the layer.
            lockfile (`Path`, *optional*): Path to the lock file that pins
                the kernel revision. When not given, the lock file of the
                calling package is used.
            layer_name (`str`): The name of the layer in the kernel repository.
        """
        self.repo_id = repo_id
        self.lockfile = lockfile
        self.layer_name = layer_name

    @functools.cached_property
    def revision(self) -> str:
        # `cached_property` caches on the instance itself. The previous
        # `@property` + `@functools.lru_cache()` combination cached on a
        # class-level LRU keyed by `self`, which (a) kept every instance
        # alive for the lifetime of the cache and (b) could return another
        # *equal* instance's revision even when its `lockfile` differed,
        # since `__eq__`/`__hash__` deliberately ignore `lockfile`.
        if self.lockfile is None:
            locked_sha = _get_caller_locked_kernel(self.repo_id)
        else:
            with open(self.lockfile, "r") as f:
                locked_sha = _get_locked_kernel(self.repo_id, f.read())

        if locked_sha is None:
            raise ValueError(f"Kernel `{self.repo_id}` is not locked")

        return locked_sha

    def __eq__(self, other):
        return (
            isinstance(other, LockedLayerRepository)
            and self.layer_name == other.layer_name
            and self.repo_id == other.repo_id
        )

    def __hash__(self):
        return hash((self.layer_name, self.repo_id))
|
||||
|
||||
|
||||
_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
|
||||
|
||||
|
||||
class _DeviceRepos(ABC):
|
||||
@ -150,10 +248,10 @@ class _DeviceRepos(ABC):
|
||||
@abstractmethod
|
||||
def repos(
|
||||
self,
|
||||
) -> Optional[Dict[Mode, LayerRepository]]: ...
|
||||
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ...
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||
"""
|
||||
Insert a repository for a specific device and mode.
|
||||
"""
|
||||
@ -161,7 +259,7 @@ class _DeviceRepos(ABC):
|
||||
|
||||
|
||||
class _MPSRepos(_DeviceRepos):
|
||||
_repos: Dict[Mode, LayerRepository]
|
||||
_repos: Dict[Mode, LayerRepositoryProtocol]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -170,10 +268,10 @@ class _MPSRepos(_DeviceRepos):
|
||||
@property
|
||||
def repos(
|
||||
self,
|
||||
) -> Optional[Dict[Mode, LayerRepository]]:
|
||||
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
||||
return self._repos
|
||||
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||
if device.type != "mps":
|
||||
raise ValueError(f"Device type must be 'mps', got {device.type}")
|
||||
|
||||
@ -181,7 +279,7 @@ class _MPSRepos(_DeviceRepos):
|
||||
|
||||
|
||||
class _CUDARepos(_DeviceRepos):
|
||||
_repos: IntervalTree[Dict[Mode, LayerRepository]]
|
||||
_repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -190,11 +288,11 @@ class _CUDARepos(_DeviceRepos):
|
||||
@property
|
||||
def repos(
|
||||
self,
|
||||
) -> Optional[Dict[Mode, LayerRepository]]:
|
||||
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
|
||||
capability = _find_capability()
|
||||
return self.repos_by_capability.find_smallest_interval(capability)
|
||||
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepository]):
|
||||
def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
|
||||
assert device.properties is None or isinstance(
|
||||
device.properties, CUDAProperties
|
||||
)
|
||||
@ -219,7 +317,10 @@ _KERNEL_MAPPING: ContextVar[Dict[str, Dict[str, _DeviceRepos]]] = ContextVar(
|
||||
def use_kernel_mapping(
|
||||
mapping: Dict[
|
||||
str,
|
||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
||||
Dict[
|
||||
Union[Device, str],
|
||||
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
|
||||
],
|
||||
],
|
||||
*,
|
||||
inherit_mapping: bool = True,
|
||||
@ -250,7 +351,10 @@ def use_kernel_mapping(
|
||||
def register_kernel_mapping(
|
||||
mapping: Dict[
|
||||
str,
|
||||
Dict[Union[Device, str], Union[LayerRepository, Dict[Mode, LayerRepository]]],
|
||||
Dict[
|
||||
Union[Device, str],
|
||||
Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]],
|
||||
],
|
||||
],
|
||||
):
|
||||
"""
|
||||
@ -283,10 +387,10 @@ def register_kernel_mapping(
|
||||
Device(type=new_device) if isinstance(new_device, str) else new_device
|
||||
)
|
||||
|
||||
if isinstance(new_repo, LayerRepository):
|
||||
kernel_options = {Mode.FALLBACK: new_repo}
|
||||
else:
|
||||
if isinstance(new_repo, dict):
|
||||
kernel_options = new_repo
|
||||
else:
|
||||
kernel_options = {Mode.FALLBACK: new_repo}
|
||||
|
||||
feature_repos = device_repo.setdefault(device.type, device.create_repo())
|
||||
feature_repos.insert(device, kernel_options)
|
||||
@ -338,10 +442,10 @@ _MODE_FALLBACK_PRIORITY = {
|
||||
|
||||
|
||||
def _select_repository(
|
||||
repositories: Dict[Mode, LayerRepository],
|
||||
repositories: Dict[Mode, LayerRepositoryProtocol],
|
||||
*,
|
||||
mode: Mode,
|
||||
) -> Optional[Tuple[LayerRepository, Mode]]:
|
||||
) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]:
|
||||
# Get the fallback priority list for the requested mode
|
||||
if mode not in _MODE_FALLBACK_PRIORITY:
|
||||
raise ValueError(f"Unsupported mode: {mode}")
|
||||
@ -612,7 +716,7 @@ def _validate_layer_has_mode(
|
||||
*,
|
||||
layer_name: str,
|
||||
module: Type["nn.Module"],
|
||||
repo: LayerRepository,
|
||||
repo: LayerRepositoryProtocol,
|
||||
repo_mode: Mode,
|
||||
):
|
||||
"""
|
||||
@ -637,7 +741,7 @@ def _validate_layer_has_mode(
|
||||
|
||||
|
||||
def _get_layer_memoize(
|
||||
repo: LayerRepository, module_class: Type["nn.Module"]
|
||||
repo: LayerRepositoryProtocol, module_class: Type["nn.Module"]
|
||||
) -> Type["nn.Module"]:
|
||||
layer = _CACHED_LAYER.get(repo, None)
|
||||
if layer is not None:
|
||||
|
@ -4,10 +4,8 @@ from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from kernels._versions import resolve_version_spec_as_ref
|
||||
from kernels.compat import tomllib
|
||||
|
||||
|
||||
@ -31,20 +29,6 @@ class KernelLock:
|
||||
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
tag_for_newest = versions[accepted_versions[-1]]
|
||||
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
|
||||
|
||||
r = HfApi().repo_info(
|
||||
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
||||
|
@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels._versions import select_revision_or_version
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
|
||||
@ -45,9 +46,11 @@ def build_variant() -> str:
|
||||
compute_framework = f"rocm{rocm_version.major}{rocm_version.minor}"
|
||||
elif torch.backends.mps.is_available():
|
||||
compute_framework = "metal"
|
||||
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||
compute_framework = "xpu"
|
||||
else:
|
||||
raise AssertionError(
|
||||
"Torch was not compiled with CUDA, Metal, or ROCm enabled."
|
||||
"Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
|
||||
)
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
@ -182,13 +185,31 @@ def install_kernel_all_variants(
|
||||
return repo_path / "build"
|
||||
|
||||
|
||||
def get_kernel(
    repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> ModuleType:
    """
    Load a kernel from the kernel hub.

    This function downloads a kernel to the local Hugging Face Hub cache
    directory (if it was not downloaded before) and then loads the kernel.

    Args:
        repo_id (`str`): The Hub repository containing the kernel.
        revision (`str`, *optional*, defaults to `"main"`): The specific
            revision (branch, tag, or commit) to download.
            Cannot be used together with `version`.
        version (`str`, *optional*): The kernel version to download. This
            can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
            Cannot be used together with `revision`.

    Returns:
        `ModuleType`: The imported kernel module.

    Raises:
        ValueError: If both `revision` and `version` are given, or if no
            available version satisfies `version`.

    Example:
        ```python
        from kernels import get_kernel

        kernel = get_kernel("username/my-kernel")
        result = kernel.kernel_function(input_data)
        ```
    """
    # Turn the (revision, version) pair into one concrete revision first.
    revision = select_revision_or_version(repo_id, revision, version)
    package_name, package_path = install_kernel(repo_id, revision=revision)
    return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
@ -201,11 +222,26 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
def has_kernel(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
(Torch version and compute framework).
|
||||
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
Cannot be used together with `version`.
|
||||
version (`str`, *optional*): The kernel version to download. This
|
||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
Returns:
|
||||
`bool`: `True` if a kernel is available for the current environment.
|
||||
"""
|
||||
revision = select_revision_or_version(repo_id, revision, version)
|
||||
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
|
12
tests/layer_locking/kernels.lock
Normal file
12
tests/layer_locking/kernels.lock
Normal file
@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"repo_id": "kernels-test/versions",
|
||||
"sha": "dc142fd6c9920c993d32be6358b78957c58681c3",
|
||||
"variants": {
|
||||
"torch-universal": {
|
||||
"hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c",
|
||||
"hash_type": "git_lfs_concat"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
2
tests/layer_locking/pyproject.toml
Normal file
2
tests/layer_locking/pyproject.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[tool.kernels.dependencies]
|
||||
"kernels-test/versions" = ">=0.1.0,<0.2.0"
|
@ -91,6 +91,25 @@ def test_has_kernel(kernel_exists):
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
def test_version():
    """`get_kernel` resolves `version` specifiers to the newest matching tag."""
    kernel = get_kernel("kernels-test/versions")
    assert kernel.version() == "0.2.0"
    kernel = get_kernel("kernels-test/versions", version="<1.0.0")
    assert kernel.version() == "0.2.0"
    kernel = get_kernel("kernels-test/versions", version="<0.2.0")
    assert kernel.version() == "0.1.1"
    kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
    assert kernel.version() == "0.1.1"

    # No tag satisfies the specifier.
    with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
        get_kernel("kernels-test/versions", version=">0.2.0")

    # `revision` and `version` are mutually exclusive. (The previous
    # `kernel = get_kernel(...)` assignment inside the raises block was
    # dead — the call raises before the name is bound.)
    with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
        get_kernel("kernels-test/versions", revision="v0.1.0", version="<1.0.0")
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
|
@ -2,9 +2,17 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
|
||||
from kernels import load_kernel
|
||||
from kernels.cli import download_kernels
|
||||
from kernels.layer import (
|
||||
LockedLayerRepository,
|
||||
Mode,
|
||||
kernelize,
|
||||
use_kernel_forward_from_hub,
|
||||
use_kernel_mapping,
|
||||
)
|
||||
|
||||
|
||||
# Mock download arguments class.
|
||||
@ -25,3 +33,28 @@ def test_load_locked():
|
||||
# Also validates that hashing works correctly.
|
||||
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
|
||||
load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
|
||||
|
||||
|
||||
def test_layer_locked():
    """A layer from a locked repository kernelizes to the locked revision."""
    project_dir = Path(__file__).parent / "layer_locking"

    @use_kernel_forward_from_hub("Version")
    class Version(nn.Module):
        def forward(self) -> str:
            return "0.0.0"

    version = Version()

    # The lock file pins `kernels-test/versions` to the 0.1.1 release.
    mapping = {
        "Version": {
            "cuda": LockedLayerRepository(
                repo_id="kernels-test/versions",
                layer_name="Version",
                lockfile=project_dir / "kernels.lock",
            )
        },
    }

    with use_kernel_mapping(mapping):
        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
        assert version() == "0.1.1"
|
||||
|
@ -801,7 +801,8 @@ def test_kernel_modes_cross_fallback():
|
||||
{
|
||||
"Linear": {
|
||||
"cuda": {
|
||||
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
||||
Mode.TRAINING
|
||||
| Mode.TORCH_COMPILE: LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
)
|
||||
@ -839,7 +840,8 @@ def test_kernel_modes_cross_fallback():
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
),
|
||||
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
|
||||
Mode.INFERENCE
|
||||
| Mode.TORCH_COMPILE: LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
),
|
||||
@ -857,3 +859,95 @@ def test_kernel_modes_cross_fallback():
|
||||
linear(X)
|
||||
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
|
||||
assert linear.n_calls == 2
|
||||
|
||||
|
||||
def test_layer_versions():
    """Layer repositories resolve `version` specifiers like kernels do."""

    @use_kernel_forward_from_hub("Version")
    class Version(nn.Module):
        def forward(self) -> str:
            return "0.0.0"

    version = Version()

    def _mapping(**repo_kwargs):
        # Build a kernel mapping for the test layer; `repo_kwargs` are
        # forwarded to `LayerRepository` (e.g. `version=...`, `revision=...`).
        return {
            "Version": {
                Device(type="cuda"): LayerRepository(
                    repo_id="kernels-test/versions",
                    layer_name="Version",
                    **repo_kwargs,
                )
            }
        }

    # Each specifier must resolve to the newest matching kernel version.
    # (Table-driven instead of five copy-pasted `with` stanzas.)
    for repo_kwargs, expected in [
        ({}, "0.2.0"),
        ({"version": "<1.0.0"}, "0.2.0"),
        ({"version": "<0.2.0"}, "0.1.1"),
        ({"version": ">0.1.0,<0.2.0"}, "0.1.1"),
    ]:
        with use_kernel_mapping(_mapping(**repo_kwargs)):
            version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
            assert version() == expected

    # No tag satisfies the specifier: resolution fails at kernelize time.
    with use_kernel_mapping(_mapping(version=">0.2.0")):
        with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
            kernelize(version, device="cuda", mode=Mode.INFERENCE)

    # `revision` and `version` are mutually exclusive; the repository
    # constructor raises while the mapping is being built.
    with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
        use_kernel_mapping(_mapping(revision="v0.1.0", version="<1.0.0"))
|
||||
|
Reference in New Issue
Block a user