mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 21:10:02 +08:00
Add version support to LayerRepository
(#113)
* Add version support to `LayerRepository` * Remove some docs that do not apply * Removed unused member variable
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
@ -37,3 +37,16 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
|
||||
)
|
||||
|
||||
return versions[accepted_versions[-1]]
|
||||
|
||||
|
||||
def select_revision_or_version(
|
||||
repo_id: str, revision: Optional[str], version: Optional[str]
|
||||
) -> str:
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError("Either a revision or a version must be specified, not both.")
|
||||
elif revision is None and version is None:
|
||||
revision = "main"
|
||||
elif version is not None:
|
||||
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
|
||||
assert revision is not None
|
||||
return revision
|
||||
|
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@ -8,7 +9,7 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextvars import ContextVar
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from enum import Flag, auto
|
||||
from functools import lru_cache
|
||||
from types import MethodType
|
||||
@ -22,6 +23,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from ._interval_tree import IntervalTree
|
||||
from ._versions import select_revision_or_version
|
||||
from .utils import get_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -112,30 +114,63 @@ class CUDAProperties:
|
||||
return hash((self.min_capability, self.max_capability))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerRepository:
|
||||
"""
|
||||
Repository and name of a layer.
|
||||
"""
|
||||
|
||||
layer_name: str = field(
|
||||
metadata={"help": "The name of the layer in the kernel repository."}
|
||||
)
|
||||
repo_id: str = field(metadata={"help": "The kernel hub repository with the layer."})
|
||||
revision: str = field(
|
||||
default="main", metadata={"help": "The revision of the layer."}
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
layer_name: str,
|
||||
revision: Optional[str] = None,
|
||||
version: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Construct a layer repository.
|
||||
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the layer.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
Cannot be used together with `version`.
|
||||
version (`str`, *optional*): The kernel version to download. This
|
||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
"""
|
||||
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError(
|
||||
"Either a revision or a version must be specified, not both."
|
||||
)
|
||||
|
||||
self.repo_id = repo_id
|
||||
self.layer_name = layer_name
|
||||
|
||||
# We are going to resolve these lazily, since we do not want
|
||||
# to do a network request for every registered LayerRepository.
|
||||
self._revision = revision
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
@functools.lru_cache()
|
||||
def revision(self) -> str:
|
||||
return select_revision_or_version(
|
||||
repo_id=self.repo_id, revision=self._revision, version=self._version
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, LayerRepository)
|
||||
and self.layer_name == other.layer_name
|
||||
and self.repo_id == other.repo_id
|
||||
and self.revision == other.revision
|
||||
and self._revision == other._revision
|
||||
and self._version == other._version
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.layer_name, self.repo_id, self.revision))
|
||||
return hash((self.layer_name, self.repo_id, self._revision, self._version))
|
||||
|
||||
|
||||
_CACHED_LAYER: Dict[LayerRepository, Type["nn.Module"]] = {}
|
||||
|
@ -16,7 +16,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels._versions import resolve_version_spec_as_ref
|
||||
from kernels._versions import select_revision_or_version
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
|
||||
@ -209,7 +209,7 @@ def get_kernel(
|
||||
result = kernel.kernel_function(input_data)
|
||||
```
|
||||
"""
|
||||
revision = _revision_or_version(repo_id, revision, version)
|
||||
revision = select_revision_or_version(repo_id, revision, version)
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
@ -240,7 +240,7 @@ def has_kernel(
|
||||
Returns:
|
||||
`bool`: `true` if a kernel is avaialble for the current environment.
|
||||
"""
|
||||
revision = _revision_or_version(repo_id, revision, version)
|
||||
revision = select_revision_or_version(repo_id, revision, version)
|
||||
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
@ -422,16 +422,3 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
|
||||
|
||||
def package_name_from_repo_id(repo_id: str) -> str:
|
||||
return repo_id.split("/")[-1].replace("-", "_")
|
||||
|
||||
|
||||
def _revision_or_version(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> str:
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError("Either a revision or a version must be specified, not both.")
|
||||
elif revision is None and version is None:
|
||||
revision = "main"
|
||||
elif version is not None:
|
||||
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
|
||||
assert revision is not None
|
||||
return revision
|
||||
|
@ -801,7 +801,8 @@ def test_kernel_modes_cross_fallback():
|
||||
{
|
||||
"Linear": {
|
||||
"cuda": {
|
||||
Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository(
|
||||
Mode.TRAINING
|
||||
| Mode.TORCH_COMPILE: LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
)
|
||||
@ -839,7 +840,8 @@ def test_kernel_modes_cross_fallback():
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
),
|
||||
Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
|
||||
Mode.INFERENCE
|
||||
| Mode.TORCH_COMPILE: LayerRepository(
|
||||
repo_id="kernels-test/backward-marker-test",
|
||||
layer_name="LinearBackward",
|
||||
),
|
||||
@ -857,3 +859,95 @@ def test_kernel_modes_cross_fallback():
|
||||
linear(X)
|
||||
# TRAINING | TORCH_COMPILE should NOT fall back to inference kernels, use original
|
||||
assert linear.n_calls == 2
|
||||
|
||||
|
||||
def test_layer_versions():
|
||||
@use_kernel_forward_from_hub("Version")
|
||||
class Version(nn.Module):
|
||||
def forward(self) -> str:
|
||||
return "0.0.0"
|
||||
|
||||
version = Version()
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
assert version() == "0.2.0"
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
version="<1.0.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
assert version() == "0.2.0"
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
version="<0.2.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
assert version() == "0.1.1"
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
version=">0.1.0,<0.2.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
assert version() == "0.1.1"
|
||||
|
||||
with use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
version=">0.2.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
):
|
||||
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||
kernelize(version, device="cuda", mode=Mode.INFERENCE)
|
||||
|
||||
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||
use_kernel_mapping(
|
||||
{
|
||||
"Version": {
|
||||
Device(type="cuda"): LayerRepository(
|
||||
repo_id="kernels-test/versions",
|
||||
layer_name="Version",
|
||||
revision="v0.1.0",
|
||||
version="<1.0.0",
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
Reference in New Issue
Block a user