Compare commits

...

2 Commits

Author SHA1 Message Date
ed048616fe Set version to 0.10.4.dev0 (#169) 2025-10-16 20:21:35 +02:00
b182cd3458 feat: allow get_kernel to log telemetry. (#167)
* feat: allow get_kernel to log telemetry.

* Apply suggestions from code review

Co-authored-by: Daniël de Kok <me@danieldk.eu>

* doc

---------

Co-authored-by: Daniël de Kok <me@danieldk.eu>
2025-10-16 20:16:41 +02:00
3 changed files with 49 additions and 4 deletions

View File

@ -39,3 +39,13 @@ The approach of `forward`-replacement is the least invasive, because
it preserves the original model graph. It is also reversible, since
even though the `forward` of a layer _instance_ might be replaced,
the corresponding class still has the original `forward`.
## Misc
### How can I disable kernel reporting in the user-agent?
By default, we collect telemetry when a call to `get_kernel()` is made.
This only includes the `kernels` version, `torch` version, and the build
information for the kernel being requested.
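You can disable this by setting the `DISABLE_TELEMETRY` environment variable to a true value (`1`, `ON`, `YES`, or `TRUE`, case-insensitive), for example `export DISABLE_TELEMETRY=yes`.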

View File

@ -1,6 +1,6 @@
[project]
name = "kernels"
version = "0.10.3.dev0"
version = "0.10.4.dev0"
description = "Download compute kernels"
authors = [
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },

View File

@ -11,7 +11,7 @@ import sys
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
from huggingface_hub import file_exists, snapshot_download
from packaging.version import parse
@ -19,6 +19,8 @@ from packaging.version import parse
from kernels._versions import select_revision_or_version
from kernels.lockfile import KernelLock, VariantLock
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
def _get_cache_dir() -> Optional[str]:
"""Returns the kernels cache directory."""
@ -108,6 +110,7 @@ def install_kernel(
revision: str,
local_files_only: bool = False,
variant_locks: Optional[Dict[str, VariantLock]] = None,
user_agent: Optional[Union[str, dict]] = None,
) -> Tuple[str, Path]:
"""
Download a kernel for the current environment to the cache.
@ -123,6 +126,8 @@ def install_kernel(
Whether to only use local files and not download from the Hub.
variant_locks (`Dict[str, VariantLock]`, *optional*):
Optional dictionary of variant locks for validation.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
Returns:
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
@ -130,6 +135,7 @@ def install_kernel(
package_name = package_name_from_repo_id(repo_id)
variant = build_variant()
universal_variant = universal_build_variant()
user_agent = _get_user_agent(user_agent=user_agent)
repo_path = Path(
snapshot_download(
repo_id,
@ -137,6 +143,7 @@ def install_kernel(
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
user_agent=user_agent,
)
)
@ -213,7 +220,10 @@ def install_kernel_all_variants(
def get_kernel(
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
repo_id: str,
revision: Optional[str] = None,
version: Optional[str] = None,
user_agent: Optional[Union[str, dict]] = None,
) -> ModuleType:
"""
Load a kernel from the kernel hub.
@ -229,6 +239,8 @@ def get_kernel(
version (`str`, *optional*):
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
Cannot be used together with `revision`.
user_agent (`Union[str, dict]`, *optional*):
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
Returns:
`ModuleType`: The imported kernel module.
@ -245,7 +257,9 @@ def get_kernel(
```
"""
revision = select_revision_or_version(repo_id, revision, version)
package_name, package_path = install_kernel(repo_id, revision=revision)
package_name, package_path = install_kernel(
repo_id, revision=revision, user_agent=user_agent
)
return import_from_path(package_name, package_path / package_name / "__init__.py")
@ -501,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
def package_name_from_repo_id(repo_id: str) -> str:
    """
    Derive the Python package name from a Hub repository id.

    The text after the last `/` in `repo_id` is taken as the repository
    name, and every `-` in it is replaced by `_`.
    """
    _, _, repo_name = repo_id.rpartition("/")
    return repo_name.replace("-", "_")
def _get_user_agent(
    user_agent: Optional[Union[dict, str]] = None,
) -> Union[None, dict, str]:
    """
    Resolve the `user_agent` value to use for telemetry.

    Args:
        user_agent (`Union[dict, str]`, *optional*):
            Caller-supplied user-agent info. When given (and telemetry is
            not disabled), it is returned unchanged.

    Returns:
        `Union[None, dict, str]`: `None` when the `DISABLE_TELEMETRY`
        environment variable is set to one of `ENV_VARS_TRUE_VALUES`
        (compared case-insensitively); otherwise the caller-supplied
        `user_agent`, or, when none was given, a dict with the `kernels`
        version, the `torch` version, the build variant and the file type.
    """
    # Check the opt-out first so nothing is imported or collected when
    # telemetry is disabled.
    if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
        return None

    if user_agent is not None:
        return user_agent

    # Imported lazily: only the default payload below needs these, so the
    # opt-out and caller-supplied paths do not pay for (or fail on) them.
    import torch

    from . import __version__

    return {
        "kernels": __version__,
        "torch": torch.__version__,
        "build_variant": build_variant(),
        "file_type": "kernel",
    }