From 8069e3bf0c60c0792f2e1886e97ed900341570b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 24 Jul 2025 16:21:54 +0200 Subject: [PATCH] Update documentation for compatibility with doc-builder (#117) --- README.md | 15 +- docs/docker.md | 8 - docs/source/_toctree.yml | 28 ++ docs/source/api/kernels.md | 21 ++ docs/source/api/layers.md | 41 +++ docs/source/basic-usage.md | 34 +++ docs/{ => source}/env.md | 0 docs/{ => source}/faq.md | 0 docs/source/index.md | 20 ++ docs/source/installation.md | 16 ++ docs/{ => source}/kernel-requirements.md | 0 docs/{ => source}/layers.md | 0 docs/{ => source}/locking.md | 0 flake.nix | 4 + pyproject.toml | 3 + src/kernels/layer.py | 329 ++++++++++++++++++----- src/kernels/utils.py | 87 ++++-- 17 files changed, 506 insertions(+), 100 deletions(-) delete mode 100644 docs/docker.md create mode 100644 docs/source/_toctree.yml create mode 100644 docs/source/api/kernels.md create mode 100644 docs/source/api/layers.md create mode 100644 docs/source/basic-usage.md rename docs/{ => source}/env.md (100%) rename docs/{ => source}/faq.md (100%) create mode 100644 docs/source/index.md create mode 100644 docs/source/installation.md rename docs/{ => source}/kernel-requirements.md (100%) rename docs/{ => source}/layers.md (100%) rename docs/{ => source}/locking.md (100%) diff --git a/README.md b/README.md index 7f33bf8..0fc3907 100644 --- a/README.md +++ b/README.md @@ -56,10 +56,13 @@ the Hub. ## 📚 Documentation -- [Using layers](docs/layers.md) -- [Locking kernel/layer versions](docs/locking.md) -- [Environment variables](docs/env.md) -- [Using kernels in a Docker container](docs/docker.md) -- [Kernel requirements](docs/kernel-requirements.md) -- [Frequently Asked Questions](docs/faq.md) +- [Introduction](docs/source/index.md) +- [Installation](docs/source/installation.md) +- [Basic usage](docs/source/basic-usage.md) +- [Using layers](docs/source/layers.md) +- [Locking kernel/layer versions](docs/source/locking.md) +- [Environment variables](docs/source/env.md) +- [Using kernels in a Docker container](docs/source/docker.md) +- [Kernel requirements](docs/source/kernel-requirements.md) +- [Frequently Asked Questions](docs/source/faq.md) - [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/) diff --git a/docs/docker.md b/docs/docker.md deleted file mode 100644 index 8ed3bb8..0000000 --- a/docs/docker.md +++ /dev/null @@ -1,8 +0,0 @@ -# Using kernels in a Docker container - -build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands: - -```bash -docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . -docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference -``` diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml new file mode 100644 index 0000000..6067331 --- /dev/null +++ b/docs/source/_toctree.yml @@ -0,0 +1,28 @@ +- sections: + - local: index + title: Introduction + - local: installation + title: Installation + title: Getting started +- sections: + - local: basic-usage + title: Basic Usage + - local: layers + title: Using Layers + - local: locking + title: Locking Kernel Versions + - local: env + title: Environment Variables + - local: faq + title: FAQ + title: Usage Guide +- sections: + - local: api/kernels + title: Kernels + - local: api/layers + title: Layers + title: API Reference +- sections: + - local: kernel-requirements + title: Kernel Requirements + title: Developer Guide diff --git a/docs/source/api/kernels.md b/docs/source/api/kernels.md new file mode 100644 index 0000000..9b05bce --- /dev/null +++ b/docs/source/api/kernels.md @@ -0,0 +1,21 @@ +# Kernels API Reference + +## Main Functions + +### get_kernel + +[[autodoc]] kernels.get_kernel + +### has_kernel + +[[autodoc]] kernels.has_kernel + +## Loading locked kernels + +### load_kernel + +[[autodoc]] kernels.load_kernel + +### get_locked_kernel + +[[autodoc]] kernels.get_locked_kernel diff --git a/docs/source/api/layers.md b/docs/source/api/layers.md new file mode 100644 index 0000000..b913914 --- /dev/null +++ b/docs/source/api/layers.md @@ -0,0 +1,41 @@ +# Layers API Reference + +## Making layers kernel-aware + +### use_kernel_forward_from_hub + +[[autodoc]] kernels.use_kernel_forward_from_hub + +### replace_kernel_forward_from_hub + +[[autodoc]] kernels.replace_kernel_forward_from_hub + +## Registering kernel mappings + +### use_kernel_mapping + +[[autodoc]] kernels.use_kernel_mapping + +### register_kernel_mapping + +[[autodoc]] kernels.register_kernel_mapping + +## Kernelizing a model + +### kernelize + +[[autodoc]] kernels.kernelize + +## Classes + +### Device + +[[autodoc]] kernels.Device + +### Mode + +[[autodoc]] kernels.Mode + +### LayerRepository + +[[autodoc]] kernels.LayerRepository diff --git a/docs/source/basic-usage.md b/docs/source/basic-usage.md new file mode 100644 index 0000000..60f9975 --- /dev/null +++ b/docs/source/basic-usage.md @@ -0,0 +1,34 @@ +# Basic Usage + +## Loading Kernels + +Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub: + +```python +import torch +from kernels import get_kernel + +# Download optimized kernels from the Hugging Face hub +activation = get_kernel("kernels-community/activation") + +# Create a random tensor +x = torch.randn((10, 10), dtype=torch.float16, device="cuda") + +# Run the kernel +y = torch.empty_like(x) +activation.gelu_fast(y, x) + +print(y) +``` + +## Checking Kernel Availability + +You can check if a specific kernel is available for your environment: + +```python +from kernels import has_kernel + +# Check if kernel is available for current environment +is_available = has_kernel("kernels-community/activation") +print(f"Kernel available: {is_available}") +``` diff --git a/docs/env.md b/docs/source/env.md similarity index 100% rename from docs/env.md rename to docs/source/env.md diff --git a/docs/faq.md b/docs/source/faq.md similarity index 100% rename from docs/faq.md rename to docs/source/faq.md diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 0000000..3d44229 --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,20 @@ +# Kernels + +
+kernel-builder logo +
+ +The Kernel Hub allows Python libraries and applications to load compute +kernels directly from the [Hub](https://hf.co/). To support this kind +of dynamic loading, Hub kernels differ from traditional Python kernel +packages in that they are made to be: + +- **Portable**: a kernel can be loaded from paths outside `PYTHONPATH`. +- **Unique**: multiple versions of the same kernel can be loaded in the + same Python process. +- **Compatible**: kernels must support all recent versions of Python and + the different PyTorch build configurations (various CUDA versions + and C++ ABIs). Furthermore, older C library versions must be supported. + +You can [search for kernels](https://huggingface.co/models?other=kernel) on +the Hub. diff --git a/docs/source/installation.md b/docs/source/installation.md new file mode 100644 index 0000000..c6e1914 --- /dev/null +++ b/docs/source/installation.md @@ -0,0 +1,16 @@ +# Installation + +Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA): + +```bash +pip install kernels +``` + +# Using kernels in a Docker container + +Build and run the reference `examples/basic.py` in a Docker container with the following commands: + +```bash +docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference . +docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference +``` diff --git a/docs/kernel-requirements.md b/docs/source/kernel-requirements.md similarity index 100% rename from docs/kernel-requirements.md rename to docs/source/kernel-requirements.md diff --git a/docs/layers.md b/docs/source/layers.md similarity index 100% rename from docs/layers.md rename to docs/source/layers.md diff --git a/docs/locking.md b/docs/source/locking.md similarity index 100% rename from docs/locking.md rename to docs/source/locking.md diff --git a/flake.nix b/flake.nix index 52db12c..8d1d33e 100644 --- a/flake.nix +++ b/flake.nix @@ -26,6 +26,10 @@ formatter = pkgs.nixfmt-tree; devShells = with pkgs; rec { default = mkShell { + nativeBuildInputs = [ + # For hf-doc-builder. + nodejs + ]; buildInputs = [ black diff --git a/pyproject.toml b/pyproject.toml index eb4aa0b..107773e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,9 @@ dev = [ [project.optional-dependencies] torch = ["torch"] +docs = [ + "hf-doc-builder", +] [project.scripts] kernels = "kernels.cli:main" diff --git a/src/kernels/layer.py b/src/kernels/layer.py index 8a08fd3..8d7b172 100644 --- a/src/kernels/layer.py +++ b/src/kernels/layer.py @@ -40,17 +40,19 @@ class Mode(Flag): """ Kernelize mode - The `Mode` flag is used by `kernelize` to select kernels for the given - mode. Mappings can be registered for specific modes. + The `Mode` flag is used by [`kernelize`] to select kernels for the given mode. Mappings can be registered for + specific modes. - * `INFERENCE`: The kernel is used for inference. - * `TRAINING`: The kernel is used for training. - * `TORCH_COMPILE`: The kernel is used with `torch.compile`. - * `FALLBACK`: In a kernel mapping, this kernel is used when no other mode - matches. + Attributes: + INFERENCE: The kernel is used for inference. + TRAINING: The kernel is used for training. + TORCH_COMPILE: The kernel is used with `torch.compile`. + FALLBACK: In a kernel mapping, this kernel is used when no other mode matches. + + Note: + Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE` should be used for layers that + are used for inference *with* `torch.compile`. - Different modes can be combined. For instance, `INFERENCE | TORCH_COMPILE` - should be used for layers that are used for inference *with* `torch.compile`. """ _NONE = 0 @@ -73,6 +75,36 @@ class Mode(Flag): @dataclass(frozen=True) class Device: + """ + Represents a compute device with optional properties. + + This class encapsulates device information including device type and optional device-specific properties + like CUDA capabilities. + + Args: + type (`str`): + The device type (e.g., "cuda", "mps", "cpu"). + properties ([`CUDAProperties`], *optional*): + Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices. + + Example: + ```python + from kernels import Device, CUDAProperties + + # Basic CUDA device + cuda_device = Device(type="cuda") + + # CUDA device with specific capability requirements + cuda_device_with_props = Device( + type="cuda", + properties=CUDAProperties(min_capability=75, max_capability=90) + ) + + # MPS device for Apple Silicon + mps_device = Device(type="mps") + ``` + """ + type: str properties: Optional[CUDAProperties] = None @@ -101,6 +133,34 @@ class Device: @dataclass(frozen=True) class CUDAProperties: + """ + CUDA-specific device properties for capability-based kernel selection. + + This class defines CUDA compute capability constraints for kernel selection, allowing kernels to specify + minimum and maximum CUDA compute capabilities they support. + + Args: + min_capability (`int`): + Minimum CUDA compute capability required (e.g., 75 for compute capability 7.5). + max_capability (`int`): + Maximum CUDA compute capability supported (e.g., 90 for compute capability 9.0). + + Example: + ```python + from kernels import CUDAProperties, Device + + # Define CUDA properties for modern GPUs (compute capability 7.5 to 9.0) + cuda_props = CUDAProperties(min_capability=75, max_capability=90) + + # Create a device with these properties + device = Device(type="cuda", properties=cuda_props) + ``` + + Note: + CUDA compute capabilities are represented as integers where the major and minor versions are concatenated. + For example, compute capability 7.5 is represented as 75, and 8.6 is represented as 86. + """ + min_capability: int max_capability: int @@ -129,7 +189,36 @@ class LayerRepositoryProtocol(Protocol): class LayerRepository: """ - Repository and name of a layer. + Repository and name of a layer for kernel mapping. + + Args: + repo_id (`str`): + The Hub repository containing the layer. + layer_name (`str`): + The name of the layer within the kernel repository. + revision (`str`, *optional*, defaults to `"main"`): + The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. + version (`str`, *optional*): + The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. + Cannot be used together with `revision`. + + Example: + ```python + from kernels import LayerRepository + + # Reference a specific layer by revision + layer_repo = LayerRepository( + repo_id="username/my-kernel", + layer_name="MyLayer", + ) + + # Reference a layer by version constraint + layer_repo_versioned = LayerRepository( + repo_id="username/my-kernel", + layer_name="MyLayer", + version=">=1.0.0,<2.0.0" + ) + ``` """ def __init__( @@ -140,19 +229,6 @@ class LayerRepository: revision: Optional[str] = None, version: Optional[str] = None, ): - """ - Construct a layer repository. - - Args: - repo_id (`str`): The Hub repository containing the layer. - revision (`str`, *optional*, defaults to `"main"`): The specific - revision (branch, tag, or commit) to download. - Cannot be used together with `version`. - version (`str`, *optional*): The kernel version to download. This - can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. - Cannot be used together with `revision`. - """ - if revision is not None and version is not None: raise ValueError( "Either a revision or a version must be specified, not both." @@ -326,11 +402,42 @@ def use_kernel_mapping( inherit_mapping: bool = True, ): """ - Context manager that sets a mapping for a duration of the context. + Context manager that sets a kernel mapping for the duration of the context. - When `inherit_mapping` is set to `True` the current mapping will be - extended by `mapping` inside the context. If it is `False`, only - `mapping` is used inside the context. + This function allows temporary kernel mappings to be applied within a specific context, enabling different + kernel configurations for different parts of your code. + + Args: + mapping (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`): + The kernel mapping to apply. Maps layer names to device-specific kernel configurations. + inherit_mapping (`bool`, *optional*, defaults to `True`): + When `True`, the current mapping will be extended by `mapping` inside the context. When `False`, + only `mapping` is used inside the context. + + Returns: + Context manager that handles the temporary kernel mapping. + + Example: + ```python + from kernels import use_kernel_mapping, LayerRepository, Device + + # Define a mapping + mapping = { + "LayerNorm": { + "cuda": LayerRepository( + repo_id="username/experimental-kernels", + layer_name="FastLayerNorm" + ) + } + } + + # Use the mapping for the duration of the context. + with use_kernel_mapping(mapping): + # kernelize uses the temporary mapping + model = kernelize(model) + + # Outside the context, original mappings are restored + ``` """ class ContextManager: @@ -358,26 +465,49 @@ def register_kernel_mapping( ], ): """ - Allows one to register a mapping between a layer name and the corresponding - kernel(s) to use, depending on the device. This should be used in conjunction - with `kernelize`. + Register a global mapping between layer names and their corresponding kernel implementations. - Example usage: + This function allows you to register a mapping between a layer name and the corresponding kernel(s) to use, + depending on the device and mode. This should be used in conjunction with [`kernelize`]. - ```python - from kernels import LayerRepository, register_kernel_mapping + Args: + mapping (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`): + The kernel mapping to register globally. Maps layer names to device-specific kernels. + The mapping can specify different kernels for different modes (training, inference, etc.). - kernel_layer_mapping = { - "LlamaRMSNorm": { - "cuda": LayerRepository( - repo_id="kernels-community/activation", - layer_name="RmsNorm", - revision="layers", - ), - }, - } - register_kernel_mapping(kernel_layer_mapping) - ``` + Example: + ```python + from kernels import LayerRepository, register_kernel_mapping, Mode + + # Simple mapping for a single kernel per device + kernel_layer_mapping = { + "LlamaRMSNorm": { + "cuda": LayerRepository( + repo_id="kernels-community/activation", + layer_name="RmsNorm", + revision="layers", + ), + }, + } + register_kernel_mapping(kernel_layer_mapping) + + # Advanced mapping with mode-specific kernels + advanced_mapping = { + "MultiHeadAttention": { + "cuda": { + Mode.TRAINING: LayerRepository( + repo_id="username/training-kernels", + layer_name="TrainingAttention" + ), + Mode.INFERENCE: LayerRepository( + repo_id="username/inference-kernels", + layer_name="FastAttention" + ), + } + } + } + register_kernel_mapping(advanced_mapping) + ``` """ # Merge with existing mappings. for new_kernel, new_device_repos in mapping.items(): @@ -401,15 +531,20 @@ def replace_kernel_forward_from_hub( layer_name: str, ): """ - Decorator that prepares a layer class to use a kernel from the Hugging Face Hub. + Function that prepares a layer class to use kernels from the Hugging Face Hub. - This decorator stores the layer name and original forward method, which will be used - by the kernelize function to replace the forward implementation with the appropriate - kernel from the hub. + It is recommended to use [`use_kernel_forward_from_hub`] decorator instead. + This function should only be used as a last resort to extend third-party layers, + it is inherently fragile since the member variables and `forward` signature + of usch a layer can change. - Args: - cls: The layer class to decorate - layer_name: The name of the layer to use for kernel lookup + Example: + ```python + from kernels import replace_kernel_forward_from_hub + import torch.nn as nn + + replace_kernel_forward_from_hub(nn.LayerNorm, "LayerNorm") + ``` """ cls.kernel_layer_name = layer_name @@ -468,23 +603,57 @@ def kernelize( use_fallback: bool = True, ): """ - Iterate over all modules in the model and replace the `forward` method of - extensible layers for which kernels are registered using `register_kernel_mapping` - or `use_kernel_mapping`. + Replace layer forward methods with optimized kernel implementations. + + This function iterates over all modules in the model and replaces the `forward` method of extensible layers + for which kernels are registered using [`register_kernel_mapping`] or [`use_kernel_mapping`]. Args: - model: The PyTorch model to kernelize - mode: the mode that the kernel is going to be used in (e.g. - `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training - and `torch.compile`). - device: The device type to load kernels for. The device type will be inferred - from the parameters of the model when not provided. - use_fallback: Whether to use the original forward method of modules when no - compatible kernel could be found. If set to `False`, an exception will - be raised in such cases. + model (`nn.Module`): + The PyTorch model to kernelize. + mode ([`Mode`], *optional*, defaults to `Mode.TRAINING | Mode.TORCH_COMPILE`): + The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE` + kernelizes the model for training with `torch.compile`. + device (`Union[str, torch.device]`, *optional*): + The device type to load kernels for. The device type will be inferred from the model parameters + when not provided. + use_fallback (`bool`, *optional*, defaults to `True`): + Whether to use the original forward method of modules when no compatible kernel could be found. + If set to `False`, an exception will be raised in such cases. Returns: - The kernelized model + `nn.Module`: The kernelized model with optimized kernel implementations. + + Example: + ```python + from kernels import kernelize, Mode, register_kernel_mapping, LayerRepository + import torch.nn as nn + + @use_kernel_forward_from_hub("LayerNorm") + class LayerNorm(nn.Module): + ... + + # First register some kernel mappings + mapping = { + "LayerNorm": { + "cuda": LayerRepository( + repo_id="username/fast-kernels", + layer_name="FastLayerNorm" + ) + } + } + register_kernel_mapping(mapping) + + # Create and kernelize a model + model = nn.Sequential( + nn.Linear(768, 768), + LayerNorm(768), + nn.Linear(768, 768) + ) + + # Kernelize for inference + kernelized_model = kernelize(model) + ``` """ import torch @@ -593,7 +762,37 @@ def kernelize( def use_kernel_forward_from_hub(layer_name: str): """ - Make a layer extensible using the name `layer_name`. + Decorator factory that makes a layer extensible using the specified layer name. + + This is a decorator factory that returns a decorator which prepares a layer class to use kernels from the + Hugging Face Hub. + + Args: + layer_name (`str`): + The name of the layer to use for kernel lookup in registered mappings. + + Returns: + `Callable`: A decorator function that can be applied to layer classes. + + Example: + ```python + from kernels import use_kernel_forward_from_hub + import torch.nn as nn + + @use_kernel_forward_from_hub("MyCustomLayer") + class MyCustomLayer(nn.Module): + def __init__(self, hidden_size): + super().__init__() + self.hidden_size = hidden_size + + def forward(self, x): + # original implementation + return x + + # The layer can now be kernelized + model = MyCustomLayer(768) + kernelized_model = kernelize(model) + ``` """ def decorator(cls): diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 17af887..a24d610 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -98,7 +98,20 @@ def install_kernel( """ Download a kernel for the current environment to the cache. - The output path is validated againt `hash` when set. + The output path is validated against the hashes in `variant_locks` when provided. + + Args: + repo_id (`str`): + The Hub repository containing the kernel. + revision (`str`): + The specific revision (branch, tag, or commit) to download. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only use local files and not download from the Hub. + variant_locks (`Dict[str, VariantLock]`, *optional*): + Optional dictionary of variant locks for validation. + + Returns: + `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory. """ package_name = package_name_from_repo_id(repo_id) variant = build_variant() @@ -190,18 +203,22 @@ def get_kernel( ) -> ModuleType: """ Load a kernel from the kernel hub. - This function downloads a kernel to the local Hugging Face Hub cache - directory (if it was not downloaded before) and then loads the kernel. + + This function downloads a kernel to the local Hugging Face Hub cache directory (if it was not downloaded before) + and then loads the kernel. + Args: - repo_id (`str`): The Hub repository containing the kernel. - revision (`str`, *optional*, defaults to `"main"`): The specific - revision (branch, tag, or commit) to download. - Cannot be used together with `version`. - version (`str`, *optional*): The kernel version to download. This - can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. + repo_id (`str`): + The Hub repository containing the kernel. + revision (`str`, *optional*, defaults to `"main"`): + The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. + version (`str`, *optional*): + The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. Cannot be used together with `revision`. + Returns: `ModuleType`: The imported kernel module. + Example: ```python from kernels import get_kernel @@ -217,6 +234,15 @@ def get_kernel( def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: """ Import a kernel from a local kernel repository path. + + Args: + repo_path (`Path`): + The local path to the kernel repository. + package_name (`str`): + The name of the package to import from the repository. + + Returns: + `ModuleType`: The imported kernel module. """ package_name, package_path = _load_kernel_from_path(repo_path, package_name) return import_from_path(package_name, package_path / package_name / "__init__.py") @@ -226,19 +252,19 @@ def has_kernel( repo_id: str, revision: Optional[str] = None, version: Optional[str] = None ) -> bool: """ - Check whether a kernel build exists for the current environment - (Torch version and compute framework). + Check whether a kernel build exists for the current environment (Torch version and compute framework). Args: - repo_id (`str`): The Hub repository containing the kernel. - revision (`str`, *optional*, defaults to `"main"`): The specific - revision (branch, tag, or commit) to download. - Cannot be used together with `version`. - version (`str`, *optional*): The kernel version to download. This - can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. + repo_id (`str`): + The Hub repository containing the kernel. + revision (`str`, *optional*, defaults to `"main"`): + The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. + version (`str`, *optional*): + The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. Cannot be used together with `revision`. + Returns: - `bool`: `true` if a kernel is avaialble for the current environment. + `bool`: `True` if a kernel is available for the current environment. """ revision = select_revision_or_version(repo_id, revision, version) @@ -264,8 +290,16 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: """ Get a pre-downloaded, locked kernel. - If `lockfile` is not specified, the lockfile will be loaded from the - caller's package metadata. + If `lockfile` is not specified, the lockfile will be loaded from the caller's package metadata. + + Args: + repo_id (`str`): + The Hub repository containing the kernel. + lockfile (`Path`, *optional*): + Path to the lockfile. If not provided, the lockfile will be loaded from the caller's package metadata. + + Returns: + `ModuleType`: The imported kernel module. """ if lockfile is None: locked_sha = _get_caller_locked_kernel(repo_id) @@ -310,7 +344,18 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType: def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: - """Get a kernel using a lock file.""" + """ + Get a kernel using a lock file. + + Args: + repo_id (`str`): + The Hub repository containing the kernel. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only use local files and not download from the Hub. + + Returns: + `ModuleType`: The imported kernel module. + """ locked_sha = _get_caller_locked_kernel(repo_id) if locked_sha is None: