|
|
|
@ -11,7 +11,7 @@ import sys
|
|
|
|
|
from importlib.metadata import Distribution
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from types import ModuleType
|
|
|
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import file_exists, snapshot_download
|
|
|
|
|
from packaging.version import parse
|
|
|
|
@ -19,6 +19,8 @@ from packaging.version import parse
|
|
|
|
|
from kernels._versions import select_revision_or_version
|
|
|
|
|
from kernels.lockfile import KernelLock, VariantLock
|
|
|
|
|
|
|
|
|
|
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_cache_dir() -> Optional[str]:
|
|
|
|
|
"""Returns the kernels cache directory."""
|
|
|
|
@ -108,6 +110,7 @@ def install_kernel(
|
|
|
|
|
revision: str,
|
|
|
|
|
local_files_only: bool = False,
|
|
|
|
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
|
|
|
|
user_agent: Optional[Union[str, dict]] = None,
|
|
|
|
|
) -> Tuple[str, Path]:
|
|
|
|
|
"""
|
|
|
|
|
Download a kernel for the current environment to the cache.
|
|
|
|
@ -123,6 +126,8 @@ def install_kernel(
|
|
|
|
|
Whether to only use local files and not download from the Hub.
|
|
|
|
|
variant_locks (`Dict[str, VariantLock]`, *optional*):
|
|
|
|
|
Optional dictionary of variant locks for validation.
|
|
|
|
|
user_agent (`Union[str, dict]`, *optional*):
|
|
|
|
|
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
|
|
|
|
@ -130,6 +135,7 @@ def install_kernel(
|
|
|
|
|
package_name = package_name_from_repo_id(repo_id)
|
|
|
|
|
variant = build_variant()
|
|
|
|
|
universal_variant = universal_build_variant()
|
|
|
|
|
user_agent = _get_user_agent(user_agent=user_agent)
|
|
|
|
|
repo_path = Path(
|
|
|
|
|
snapshot_download(
|
|
|
|
|
repo_id,
|
|
|
|
@ -137,6 +143,7 @@ def install_kernel(
|
|
|
|
|
cache_dir=CACHE_DIR,
|
|
|
|
|
revision=revision,
|
|
|
|
|
local_files_only=local_files_only,
|
|
|
|
|
user_agent=user_agent,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -213,7 +220,10 @@ def install_kernel_all_variants(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_kernel(
|
|
|
|
|
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
|
|
|
|
|
repo_id: str,
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
version: Optional[str] = None,
|
|
|
|
|
user_agent: Optional[Union[str, dict]] = None,
|
|
|
|
|
) -> ModuleType:
|
|
|
|
|
"""
|
|
|
|
|
Load a kernel from the kernel hub.
|
|
|
|
@ -229,6 +239,8 @@ def get_kernel(
|
|
|
|
|
version (`str`, *optional*):
|
|
|
|
|
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
|
|
|
|
|
Cannot be used together with `revision`.
|
|
|
|
|
user_agent (`Union[str, dict]`, *optional*):
|
|
|
|
|
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
`ModuleType`: The imported kernel module.
|
|
|
|
@ -245,7 +257,9 @@ def get_kernel(
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
revision = select_revision_or_version(repo_id, revision, version)
|
|
|
|
|
package_name, package_path = install_kernel(repo_id, revision=revision)
|
|
|
|
|
package_name, package_path = install_kernel(
|
|
|
|
|
repo_id, revision=revision, user_agent=user_agent
|
|
|
|
|
)
|
|
|
|
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -501,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
|
|
|
|
|
|
|
|
|
|
def package_name_from_repo_id(repo_id: str) -> str:
|
|
|
|
|
return repo_id.split("/")[-1].replace("-", "_")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_user_agent(
|
|
|
|
|
user_agent: Optional[Union[dict, str]] = None,
|
|
|
|
|
) -> Union[None, dict, str]:
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from . import __version__
|
|
|
|
|
|
|
|
|
|
if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
if user_agent is None:
|
|
|
|
|
user_agent = {
|
|
|
|
|
"kernels": __version__,
|
|
|
|
|
"torch": torch.__version__,
|
|
|
|
|
"build_variant": build_variant(),
|
|
|
|
|
"file_type": "kernel",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return user_agent
|
|
|
|
|