mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 21:10:02 +08:00
get_kernel
: allow Python-style version specifiers (#111)
Use Python-style version specifiers to resolve to tags. E.g., given the presence of the tags `v0.1.0`, `v0.1.1`, and `v0.2.0`, get_kernel("my/kernel", version=">=0.1.0,<0.2.0") would resolve to `v0.1.1`.
This commit is contained in:
39
src/kernels/_versions.py
Normal file
39
src/kernels/_versions.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Dict
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
return versions[accepted_versions[-1]]
|
@ -4,10 +4,8 @@ from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from kernels._versions import resolve_version_spec_as_ref
|
||||
from kernels.compat import tomllib
|
||||
|
||||
|
||||
@ -31,20 +29,6 @@ class KernelLock:
|
||||
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
@ -52,16 +36,7 @@ def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
tag_for_newest = versions[accepted_versions[-1]]
|
||||
tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
|
||||
|
||||
r = HfApi().repo_info(
|
||||
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
||||
|
@ -16,6 +16,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
from huggingface_hub import file_exists, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
from kernels._versions import resolve_version_spec_as_ref
|
||||
from kernels.lockfile import KernelLock, VariantLock
|
||||
|
||||
|
||||
@ -182,13 +183,31 @@ def install_kernel_all_variants(
|
||||
return repo_path / "build"
|
||||
|
||||
|
||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
def get_kernel(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> ModuleType:
|
||||
"""
|
||||
Download and import a kernel from the Hugging Face Hub.
|
||||
|
||||
The kernel is downloaded from the repository `repo_id` at
|
||||
branch/commit/tag `revision`.
|
||||
Load a kernel from the kernel hub.
|
||||
This function downloads a kernel to the local Hugging Face Hub cache
|
||||
directory (if it was not downloaded before) and then loads the kernel.
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
Cannot be used together with `version`.
|
||||
version (`str`, *optional*): The kernel version to download. This
|
||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
Returns:
|
||||
`ModuleType`: The imported kernel module.
|
||||
Example:
|
||||
```python
|
||||
from kernels import get_kernel
|
||||
kernel = get_kernel("username/my-kernel")
|
||||
result = kernel.kernel_function(input_data)
|
||||
```
|
||||
"""
|
||||
revision = _revision_or_version(repo_id, revision, version)
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
@ -201,11 +220,26 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def has_kernel(repo_id: str, revision: str = "main") -> bool:
|
||||
def has_kernel(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Check whether a kernel build exists for the current environment
|
||||
(Torch version and compute framework).
|
||||
|
||||
Args:
|
||||
repo_id (`str`): The Hub repository containing the kernel.
|
||||
revision (`str`, *optional*, defaults to `"main"`): The specific
|
||||
revision (branch, tag, or commit) to download.
|
||||
Cannot be used together with `version`.
|
||||
version (`str`, *optional*): The kernel version to download. This
|
||||
can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
||||
Cannot be used together with `revision`.
|
||||
Returns:
|
||||
`bool`: `true` if a kernel is avaialble for the current environment.
|
||||
"""
|
||||
revision = _revision_or_version(repo_id, revision, version)
|
||||
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
universal_variant = universal_build_variant()
|
||||
@ -386,3 +420,16 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
|
||||
|
||||
def package_name_from_repo_id(repo_id: str) -> str:
|
||||
return repo_id.split("/")[-1].replace("-", "_")
|
||||
|
||||
|
||||
def _revision_or_version(
|
||||
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
||||
) -> str:
|
||||
if revision is not None and version is not None:
|
||||
raise ValueError("Either a revision or a version must be specified, not both.")
|
||||
elif revision is None and version is None:
|
||||
revision = "main"
|
||||
elif version is not None:
|
||||
revision = resolve_version_spec_as_ref(repo_id, version).target_commit
|
||||
assert revision is not None
|
||||
return revision
|
||||
|
@ -91,6 +91,25 @@ def test_has_kernel(kernel_exists):
|
||||
assert has_kernel(repo_id, revision=revision) == kernel
|
||||
|
||||
|
||||
def test_version():
|
||||
kernel = get_kernel("kernels-test/versions")
|
||||
assert kernel.version() == "0.2.0"
|
||||
kernel = get_kernel("kernels-test/versions", version="<1.0.0")
|
||||
assert kernel.version() == "0.2.0"
|
||||
kernel = get_kernel("kernels-test/versions", version="<0.2.0")
|
||||
assert kernel.version() == "0.1.1"
|
||||
kernel = get_kernel("kernels-test/versions", version=">0.1.0,<0.2.0")
|
||||
assert kernel.version() == "0.1.1"
|
||||
|
||||
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
|
||||
get_kernel("kernels-test/versions", version=">0.2.0")
|
||||
|
||||
with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
|
||||
kernel = get_kernel(
|
||||
"kernels-test/versions", revision="v0.1.0", version="<1.0.0"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.linux_only
|
||||
def test_universal_kernel(universal_kernel):
|
||||
torch.manual_seed(0)
|
||||
|
Reference in New Issue
Block a user