Add version support to LayerRepository (#113)

* Add version support to `LayerRepository`

* Remove some docs that do not apply

* Removed unused member variable
This commit is contained in:
Daniël de Kok
2025-07-22 17:02:39 +02:00
committed by GitHub
parent 6d3c6daf20
commit 4a04c005e3
4 changed files with 159 additions and 30 deletions

View File

@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Optional
from huggingface_hub import HfApi
from huggingface_hub.hf_api import GitRefInfo
@ -37,3 +37,16 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
)
return versions[accepted_versions[-1]]
def select_revision_or_version(
repo_id: str, revision: Optional[str], version: Optional[str]
) -> str:
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")
elif revision is None and version is None:
revision = "main"
elif version is not None:
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
assert revision is not None
return revision

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import functools
import inspect
import logging
import os
@ -8,7 +9,7 @@ import warnings
from abc import ABC, abstractmethod
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Flag, auto
from functools import lru_cache
from types import MethodType
@ -22,6 +23,7 @@ from typing import (
)
from ._interval_tree import IntervalTree
from ._versions import select_revision_or_version
from .utils import get_kernel
if TYPE_CHECKING:
@ -112,30 +114,63 @@ class CUDAProperties:
return hash((self.min_capability, self.max_capability))
@dataclass
class LayerRepository:
"""
Repository and name of a layer.
"""
layer_name: str = field(
metadata={"help": "The name of the layer in the kernel repository."}
)
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
revision: str = field(
default="main", metadata={"help": "The revision of the layer."}
)
def __init__(
self,
repo_id: str,
*,
layer_name: str,
revision: Optional[str] = None,
version: Optional[str] = None,
):
"""
Construct a layer repository.
Args:
repo_id (`str`): The Hub repository containing the layer.
revision (`str`, *optional*, defaults to `"main"`): The specific
revision (branch, tag, or commit) to download.
Cannot be used together with `version`.
version (`str`, *optional*): The kernel version to download. This
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
"""
if revision is not None and version is not None:
raise ValueError(
"Either a revision or a version must be specified, not both."
)
self.repo_id = repo_id
self.layer_name = layer_name
# We are going to resolve these lazily, since we do not want
# to do a network request for every registered LayerRepository.
self._revision = revision
self._version = version
@property
@functools.lru_cache()
def revision(self) -> str:
return select_revision_or_version(
repo_id=self.repo_id, revision=self._revision, version=self._version
)
def __eq__(self, other):
return (
isinstance(other, LayerRepository)
and self.layer_name == other.layer_name
and self.repo_id == other.repo_id
and self.revision == other.revision
and self._revision == other._revision
and self._version == other._version
)
def __hash__(self):
return hash((self.layer_name, self.repo_id, self.revision))
return hash((self.layer_name, self.repo_id, self._revision, self._version))
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}

View File

@ -16,7 +16,7 @@ from typing import Dict, List, Optional, Tuple
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
from kernels._versions import resolve_version_spec_as_ref
from kernels._versions import select_revision_or_version
from kernels.lockfile import KernelLock, VariantLock
@ -209,7 +209,7 @@ def get_kernel(
result = kernel.kernel_function(input_data)
```
"""
revision = _revision_or_version(repo_id, revision, version)
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(repo_id, revision=revision)
return import_from_path(package_name, package_path / package_name / "__init__.py")
@ -240,7 +240,7 @@ def has_kernel(
Returns:
`bool`: `true` if a kernel is avaialble for the current environment.
"""
revision = _revision_or_version(repo_id, revision, version)
revision = select_revision_or_version(repo_id, revision, version)
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
@ -422,16 +422,3 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
def package_name_from_repo_id(repo_id: str) -> str:
return repo_id.split("/")[-1].replace("-", "_")
def _revision_or_version(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
) -> str:
if revision is not None and version is not None:
raise ValueError("Either a revision or a version must be specified, not both.")
elif revision is None and version is None:
revision = "main"
elif version is not None:
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
assert revision is not None
return revision

View File

@ -801,7 +801,8 @@ def test_kernel_modes_cross_fallback():
{
"Linear": {
"cuda": {
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
Mode.TRAINING
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
)
@ -839,7 +840,8 @@ def test_kernel_modes_cross_fallback():
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
Mode.INFERENCE
| Mode.TORCH_COMPILE: LayerRepository(
repo_id="kernels-test/backward-marker-test",
layer_name="LinearBackward",
),
@ -857,3 +859,95 @@ def test_kernel_modes_cross_fallback():
linear(X)
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
assert linear.n_calls == 2
def test_layer_versions():
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
return "0.0.0"
version = Version()
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<1.0.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.2.0"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.1.0,<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
assert version() == "0.1.1"
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.2.0",
)
}
}
):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE)
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
revision="v0.1.0",
version="<1.0.0",
)
}
}
)