mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-20 20:56:31 +08:00
Support kernel cache directory with HF_KERNELS_CACHE
env var (#18)
This commit is contained in:
@ -15,6 +15,8 @@ from packaging.version import parse
|
||||
from hf_kernels.compat import tomllib
|
||||
from hf_kernels.lockfile import KernelLock
|
||||
|
||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
|
||||
|
||||
def build_variant():
|
||||
import torch
|
||||
@ -43,6 +45,7 @@ def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
|
||||
repo_path = snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=f"build/{build_variant()}/*",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
@ -55,6 +58,7 @@ def install_kernel_all_variants(
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns="build/*",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
@ -63,7 +67,11 @@ def install_kernel_all_variants(
|
||||
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
|
||||
with open(
|
||||
hf_hub_download(
|
||||
repo_id, "build.toml", revision=revision, local_files_only=local_files_only
|
||||
repo_id,
|
||||
"build.toml",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
),
|
||||
"rb",
|
||||
) as f:
|
||||
@ -83,7 +91,11 @@ def load_kernel(repo_id: str):
|
||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
||||
|
||||
filename = hf_hub_download(
|
||||
repo_id, "build.toml", local_files_only=True, revision=locked_sha
|
||||
repo_id,
|
||||
"build.toml",
|
||||
cache_dir=CACHE_DIR,
|
||||
local_files_only=True,
|
||||
revision=locked_sha,
|
||||
)
|
||||
with open(filename, "rb") as f:
|
||||
metadata = tomllib.load(f)
|
||||
|
Reference in New Issue
Block a user