Support kernel cache directory with HF_KERNELS_CACHE env var (#18)

This commit is contained in:
Daniël de Kok
2025-02-04 20:18:43 +01:00
committed by GitHub
parent 03875be8a0
commit d7f3831992

View File

@ -15,6 +15,8 @@ from packaging.version import parse
from hf_kernels.compat import tomllib
from hf_kernels.lockfile import KernelLock
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
def build_variant():
import torch
@ -43,6 +45,7 @@ def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
repo_path = snapshot_download(
repo_id,
allow_patterns=f"build/{build_variant()}/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
@ -55,6 +58,7 @@ def install_kernel_all_variants(
snapshot_download(
repo_id,
allow_patterns="build/*",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
)
@ -63,7 +67,11 @@ def install_kernel_all_variants(
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
with open(
hf_hub_download(
repo_id, "build.toml", revision=revision, local_files_only=local_files_only
repo_id,
"build.toml",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
),
"rb",
) as f:
@ -83,7 +91,11 @@ def load_kernel(repo_id: str):
raise ValueError(f"Kernel `{repo_id}` is not locked")
filename = hf_hub_download(
repo_id, "build.toml", local_files_only=True, revision=locked_sha
repo_id,
"build.toml",
cache_dir=CACHE_DIR,
local_files_only=True,
revision=locked_sha,
)
with open(filename, "rb") as f:
metadata = tomllib.load(f)