mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 21:10:02 +08:00
Add a bunch of cleanups (#36)
* Remove old build backend * Add types, use `Path` where possible * Remove unused `get_metadata` function This function is also problematic, because it assumes that `build.toml` is always present.
This commit is contained in:
3
.github/workflows/test.yml
vendored
3
.github/workflows/test.yml
vendored
@ -47,5 +47,8 @@ jobs:
|
||||
- name: Install setuptools for Triton-based test
|
||||
run: uv pip install setuptools
|
||||
|
||||
- name: Check typing
|
||||
run: uv run mypy src/kernels
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests
|
||||
|
@ -23,6 +23,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"mypy == 1.14.1",
|
||||
"pytest >=8",
|
||||
# Whatever version is compatible with pytest.
|
||||
"pytest-benchmark",
|
||||
@ -34,7 +35,3 @@ kernels = "kernels.cli:main"
|
||||
[project.entry-points."egg_info.writers"]
|
||||
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||
|
||||
#[build-system]
|
||||
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
|
||||
#build-backend = "kernels.build"
|
||||
#backend-path = ["src"]
|
||||
|
@ -1,144 +0,0 @@
|
||||
"""
|
||||
Python shims for the PEP 517 and PEP 660 build backend.
|
||||
|
||||
Major imports in this module are required to be lazy:
|
||||
```
|
||||
$ hyperfine \
|
||||
"/usr/bin/python3 -c \"print('hi')\"" \
|
||||
"/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\""
|
||||
Base: Time (mean ± σ): 11.0 ms ± 1.7 ms [User: 8.5 ms, System: 2.5 ms]
|
||||
With import: Time (mean ± σ): 15.2 ms ± 2.0 ms [User: 12.3 ms, System: 2.9 ms]
|
||||
Base 1.38 ± 0.28 times faster than with import
|
||||
```
|
||||
|
||||
The same thing goes for the typing module, so we use Python 3.10 type annotations that
|
||||
don't require importing typing but then quote them so earlier Python version ignore
|
||||
them while IDEs and type checker can see through the quotes.
|
||||
"""
|
||||
|
||||
from kernels.compat import tomllib
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence # noqa:I001
|
||||
from typing import Any # noqa:I001
|
||||
|
||||
|
||||
def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
|
||||
import sys
|
||||
|
||||
if config_settings:
|
||||
print("Warning: Config settings are not supported", file=sys.stderr)
|
||||
|
||||
|
||||
def call(
|
||||
args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None
|
||||
) -> str:
|
||||
"""Invoke a uv subprocess and return the filename from stdout."""
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
warn_config_settings(config_settings)
|
||||
# Unlike `find_uv_bin`, this mechanism must work according to PEP 517
|
||||
import os
|
||||
|
||||
cwd = os.getcwd()
|
||||
filename = os.path.join(cwd, "pyproject.toml")
|
||||
with open(filename, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
for kernel, _ in (
|
||||
data.get("tool", {}).get("kernels", {}).get("dependencies", {}).items()
|
||||
):
|
||||
from kernels.utils import install_kernel
|
||||
|
||||
install_kernel(kernel, revision="main")
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin is None:
|
||||
raise RuntimeError("uv was not properly installed")
|
||||
# Forward stderr, capture stdout for the filename
|
||||
result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE)
|
||||
if result.returncode != 0:
|
||||
sys.exit(result.returncode)
|
||||
# If there was extra stdout, forward it (there should not be extra stdout)
|
||||
stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True)
|
||||
sys.stdout.writelines(stdout[:-1])
|
||||
# Fail explicitly instead of an irrelevant stacktrace
|
||||
if not stdout:
|
||||
print("uv subprocess did not return a filename on stdout", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
return stdout[-1].strip()
|
||||
|
||||
|
||||
def build_sdist(
|
||||
sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
||||
) -> str:
|
||||
"""PEP 517 hook `build_sdist`."""
|
||||
args = ["build-backend", "build-sdist", sdist_directory]
|
||||
return call(args, config_settings)
|
||||
|
||||
|
||||
def build_wheel(
|
||||
wheel_directory: str,
|
||||
config_settings: "Mapping[Any, Any] | None" = None,
|
||||
metadata_directory: "str | None" = None,
|
||||
) -> str:
|
||||
"""PEP 517 hook `build_wheel`."""
|
||||
args = ["build-backend", "build-wheel", wheel_directory]
|
||||
if metadata_directory:
|
||||
args.extend(["--metadata-directory", metadata_directory])
|
||||
return call(args, config_settings)
|
||||
|
||||
|
||||
def get_requires_for_build_sdist(
|
||||
config_settings: "Mapping[Any, Any] | None" = None,
|
||||
) -> "Sequence[str]":
|
||||
"""PEP 517 hook `get_requires_for_build_sdist`."""
|
||||
warn_config_settings(config_settings)
|
||||
return []
|
||||
|
||||
|
||||
def get_requires_for_build_wheel(
|
||||
config_settings: "Mapping[Any, Any] | None" = None,
|
||||
) -> "Sequence[str]":
|
||||
"""PEP 517 hook `get_requires_for_build_wheel`."""
|
||||
warn_config_settings(config_settings)
|
||||
return []
|
||||
|
||||
|
||||
def prepare_metadata_for_build_wheel(
|
||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
||||
) -> str:
|
||||
"""PEP 517 hook `prepare_metadata_for_build_wheel`."""
|
||||
args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory]
|
||||
return call(args, config_settings)
|
||||
|
||||
|
||||
def build_editable(
|
||||
wheel_directory: str,
|
||||
config_settings: "Mapping[Any, Any] | None" = None,
|
||||
metadata_directory: "str | None" = None,
|
||||
) -> str:
|
||||
"""PEP 660 hook `build_editable`."""
|
||||
args = ["build-backend", "build-editable", wheel_directory]
|
||||
|
||||
if metadata_directory:
|
||||
args.extend(["--metadata-directory", metadata_directory])
|
||||
return call(args, config_settings)
|
||||
|
||||
|
||||
def get_requires_for_build_editable(
|
||||
config_settings: "Mapping[Any, Any] | None" = None,
|
||||
) -> "Sequence[str]":
|
||||
"""PEP 660 hook `get_requires_for_build_editable`."""
|
||||
warn_config_settings(config_settings)
|
||||
return []
|
||||
|
||||
|
||||
def prepare_metadata_for_build_editable(
|
||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
||||
) -> str:
|
||||
"""PEP 660 hook `prepare_metadata_for_build_editable`."""
|
||||
args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory]
|
||||
return call(args, config_settings)
|
@ -1,9 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.hf_api import GitRefInfo
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
@ -30,7 +31,7 @@ class KernelLock:
|
||||
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str):
|
||||
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
@ -44,7 +45,7 @@ def _get_available_versions(repo_id: str):
|
||||
return versions
|
||||
|
||||
|
||||
def get_kernel_locks(repo_id: str, version_spec: str):
|
||||
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
|
||||
@ -75,7 +76,7 @@ def get_kernel_locks(repo_id: str, version_spec: str):
|
||||
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
||||
)
|
||||
|
||||
variant_files = {}
|
||||
variant_files: Dict[str, List[Tuple[bytes, str]]] = {}
|
||||
for sibling in r.siblings:
|
||||
if sibling.rfilename.startswith("build/torch"):
|
||||
if sibling.blob_id is None:
|
||||
@ -96,9 +97,9 @@ def get_kernel_locks(repo_id: str, version_spec: str):
|
||||
variant_locks = {}
|
||||
for variant, files in variant_files.items():
|
||||
m = hashlib.sha256()
|
||||
for filename, hash in sorted(files):
|
||||
for filename_bytes, hash in sorted(files):
|
||||
# Filename as bytes.
|
||||
m.update(filename)
|
||||
m.update(filename_bytes)
|
||||
# Git blob or LFS file hash as bytes.
|
||||
m.update(bytes.fromhex(hash))
|
||||
|
||||
|
@ -21,7 +21,7 @@ from kernels.lockfile import KernelLock, VariantLock
|
||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
|
||||
|
||||
def build_variant():
|
||||
def build_variant() -> str:
|
||||
import torch
|
||||
|
||||
if torch.version.cuda is None:
|
||||
@ -38,12 +38,12 @@ def build_variant():
|
||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
||||
|
||||
|
||||
def noarch_build_variant():
|
||||
def noarch_build_variant() -> str:
|
||||
# Once we support other frameworks, detection goes here.
|
||||
return "torch-noarch"
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path):
|
||||
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
|
||||
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
||||
# it would also be used for other imports. So, we make a module name that
|
||||
# depends on the path for it to be unique using the hex-encoded hash of
|
||||
@ -51,9 +51,13 @@ def import_from_path(module_name: str, file_path):
|
||||
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
||||
module_name = f"{module_name}_{path_hash}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
if module is None:
|
||||
raise ImportError(f"Cannot load module {module_name} from spec")
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
return module
|
||||
|
||||
|
||||
@ -62,7 +66,7 @@ def install_kernel(
|
||||
revision: str,
|
||||
local_files_only: bool = False,
|
||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||
) -> Tuple[str, str]:
|
||||
) -> Tuple[str, Path]:
|
||||
"""
|
||||
Download a kernel for the current environment to the cache.
|
||||
|
||||
@ -71,17 +75,20 @@ def install_kernel(
|
||||
package_name = package_name_from_repo_id(repo_id)
|
||||
variant = build_variant()
|
||||
noarch_variant = noarch_build_variant()
|
||||
repo_path = snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
repo_path = Path(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
)
|
||||
|
||||
variant_path = f"{repo_path}/build/{variant}"
|
||||
noarch_variant_path = f"{repo_path}/build/{noarch_variant}"
|
||||
if not os.path.exists(variant_path) and os.path.exists(noarch_variant_path):
|
||||
variant_path = repo_path / "build" / variant
|
||||
noarch_variant_path = repo_path / "build" / noarch_variant
|
||||
|
||||
if not variant_path.exists() and noarch_variant_path.exists():
|
||||
# Fall back to noarch variant.
|
||||
variant = noarch_variant
|
||||
variant_path = noarch_variant_path
|
||||
@ -92,7 +99,7 @@ def install_kernel(
|
||||
raise ValueError(f"No lock found for build variant: {variant}")
|
||||
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
|
||||
|
||||
module_init_path = f"{variant_path}/{package_name}/__init__.py"
|
||||
module_init_path = variant_path / package_name / "__init__.py"
|
||||
|
||||
if not os.path.exists(module_init_path):
|
||||
raise FileNotFoundError(
|
||||
@ -107,7 +114,7 @@ def install_kernel_all_variants(
|
||||
revision: str,
|
||||
local_files_only: bool = False,
|
||||
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||
) -> str:
|
||||
) -> Path:
|
||||
repo_path = Path(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
@ -130,29 +137,15 @@ def install_kernel_all_variants(
|
||||
repo_path=repo_path, variant=variant, hash=variant_lock.hash
|
||||
)
|
||||
|
||||
return f"{repo_path}/build"
|
||||
return repo_path / "build"
|
||||
|
||||
|
||||
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
|
||||
with open(
|
||||
hf_hub_download(
|
||||
repo_id,
|
||||
"build.toml",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
),
|
||||
"rb",
|
||||
) as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
def get_kernel(repo_id: str, revision: str = "main"):
|
||||
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def load_kernel(repo_id: str):
|
||||
def load_kernel(repo_id: str) -> ModuleType:
|
||||
"""Get a pre-downloaded, locked kernel."""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
@ -166,30 +159,32 @@ def load_kernel(repo_id: str):
|
||||
variant = build_variant()
|
||||
noarch_variant = noarch_build_variant()
|
||||
|
||||
repo_path = snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
|
||||
cache_dir=CACHE_DIR,
|
||||
local_files_only=True,
|
||||
repo_path = Path(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=[f"build/{variant}/*", f"build/{noarch_variant}/*"],
|
||||
cache_dir=CACHE_DIR,
|
||||
local_files_only=True,
|
||||
)
|
||||
)
|
||||
|
||||
variant_path = f"{repo_path}/build/{variant}"
|
||||
noarch_variant_path = f"{repo_path}/build/{noarch_variant}"
|
||||
if not os.path.exists(variant_path) and os.path.exists(noarch_variant_path):
|
||||
variant_path = repo_path / "build" / variant
|
||||
noarch_variant_path = repo_path / "build" / noarch_variant
|
||||
if not variant_path.exists() and noarch_variant_path.exists():
|
||||
# Fall back to noarch variant.
|
||||
variant = noarch_variant
|
||||
variant_path = noarch_variant_path
|
||||
|
||||
module_init_path = f"{variant_path}/{package_name}/__init__.py"
|
||||
module_init_path = variant_path / package_name / "__init__.py"
|
||||
if not os.path.exists(module_init_path):
|
||||
raise FileNotFoundError(
|
||||
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
|
||||
)
|
||||
|
||||
return import_from_path(package_name, f"{variant_path}/{package_name}/__init__.py")
|
||||
return import_from_path(package_name, variant_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False):
|
||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||
"""Get a kernel using a lock file."""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
@ -200,7 +195,7 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False):
|
||||
repo_id, locked_sha, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
||||
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||
|
||||
|
||||
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
||||
@ -239,9 +234,9 @@ def _get_caller_module() -> Optional[ModuleType]:
|
||||
return first_module
|
||||
|
||||
|
||||
def validate_kernel(*, repo_path: str, variant: str, hash: str):
|
||||
def validate_kernel(*, repo_path: Path, variant: str, hash: str):
|
||||
"""Validate the given build variant of a kernel against a hasht."""
|
||||
variant_path = Path(repo_path) / "build" / variant
|
||||
variant_path = repo_path / "build" / variant
|
||||
|
||||
# Get the file paths. The first element is a byte-encoded relative path
|
||||
# used for sorting. The second element is the absolute path.
|
||||
@ -263,8 +258,8 @@ def validate_kernel(*, repo_path: str, variant: str, hash: str):
|
||||
|
||||
m = hashlib.sha256()
|
||||
|
||||
for filename, full_path in sorted(files):
|
||||
m.update(filename)
|
||||
for filename_bytes, full_path in sorted(files):
|
||||
m.update(filename_bytes)
|
||||
|
||||
blob_filename = full_path.resolve().name
|
||||
if len(blob_filename) == 40:
|
||||
|
Reference in New Issue
Block a user