mirror of
https://github.com/huggingface/kernels.git
synced 2025-11-06 23:24:31 +08:00
Compare commits
19 Commits
bump-0.1.6
...
v0.2.0
| Author | SHA1 | Date | |
|---|---|---|---|
| 18ecd0ce69 | |||
| b4ef1d60e5 | |||
| a40756f306 | |||
| 3671158f47 | |||
| 2ddd473cf7 | |||
| 497dffb89e | |||
| f036fd09cb | |||
| 3e4c83c798 | |||
| 4116d6019e | |||
| bd166b348a | |||
| 386c2a104e | |||
| c7516b9e50 | |||
| a8dcd1f6bc | |||
| af7fdf9202 | |||
| 9426e7e290 | |||
| df2c165d61 | |||
| d89239464a | |||
| 3212affd9e | |||
| 7ff40a859c |
10
.github/workflows/lint.yml
vendored
Normal file
10
.github/workflows/lint.yml
vendored
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
name: Lints
|
||||||
|
on: [push, pull_request]
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Run lints
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Run ruff
|
||||||
|
uses: astral-sh/ruff-action@v3
|
||||||
11
.github/workflows/test.yml
vendored
11
.github/workflows/test.yml
vendored
@ -1,4 +1,4 @@
|
|||||||
name: Test hf-kernels
|
name: Test kernels
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@ -26,6 +26,9 @@ jobs:
|
|||||||
python-version: ["3.10", "3.12"]
|
python-version: ["3.10", "3.12"]
|
||||||
torch-version: ["2.5.1", "2.6.0"]
|
torch-version: ["2.5.1", "2.6.0"]
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_PYTHON_PREFERENCE: only-managed
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@ -41,5 +44,11 @@ jobs:
|
|||||||
- name: Install the project
|
- name: Install the project
|
||||||
run: uv sync --all-extras --dev
|
run: uv sync --all-extras --dev
|
||||||
|
|
||||||
|
- name: Install setuptools for Triton-based test
|
||||||
|
run: uv pip install setuptools
|
||||||
|
|
||||||
|
- name: Check typing
|
||||||
|
run: uv run mypy src/kernels
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: uv run pytest tests
|
run: uv run pytest tests
|
||||||
|
|||||||
85
README.md
85
README.md
@ -1,11 +1,31 @@
|
|||||||
# hf-kernels
|
# kernels
|
||||||
|
|
||||||
Make sure you have `torch==2.5.1+cu124` installed.
|
The Kernel Hub allows Python libraries and applications to load compute
|
||||||
|
kernels directly from the [Hub](https://hf.co/). To support this kind
|
||||||
|
of dynamic loading, Hub kernels differ from traditional Python kernel
|
||||||
|
packages in that they are made to be:
|
||||||
|
|
||||||
|
- Portable: a kernel can be loaded from paths outside `PYTHONPATH`.
|
||||||
|
- Unique: multiple versions of the same kernel can be loaded in the
|
||||||
|
same Python process.
|
||||||
|
- Compatible: kernels must support all recent versions of Python and
|
||||||
|
the different PyTorch build configurations (various CUDA versions
|
||||||
|
and C++ ABIs). Furthermore, older C library versions must be supported.
|
||||||
|
|
||||||
|
## 🚀 Quick Start
|
||||||
|
|
||||||
|
Install the `kernels` package with `pip` (requires `torch>=2.5` and CUDA):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install kernels
|
||||||
|
```
|
||||||
|
|
||||||
|
Here is how you would use the [activation](https://huggingface.co/kernels-community/activation) kernels from the Hugging Face Hub:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hf_kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
# Download optimized kernels from the Hugging Face hub
|
# Download optimized kernels from the Hugging Face hub
|
||||||
activation = get_kernel("kernels-community/activation")
|
activation = get_kernel("kernels-community/activation")
|
||||||
@ -20,57 +40,12 @@ activation.gelu_fast(y, x)
|
|||||||
print(y)
|
print(y)
|
||||||
```
|
```
|
||||||
|
|
||||||
## Docker Reference
|
You can [search for kernels](https://huggingface.co/models?other=kernel) on
|
||||||
|
the Hub.
|
||||||
|
|
||||||
build and run the reference [example/basic.py](example/basic.py) in a Docker container with the following commands:
|
## 📚 Documentation
|
||||||
|
|
||||||
```bash
|
- [Locking kernel versions](docs/locking.md)
|
||||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
- [Using kernels in a Docker container](docs/docker.md)
|
||||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
- [Kernel requirements](docs/kernel-requirements.md)
|
||||||
```
|
- [Writing kernels](https://github.com/huggingface/kernel-builder/blob/main/docs/writing-kernels.md) using [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
|
||||||
## Locking kernel versions
|
|
||||||
|
|
||||||
Projects that use `setuptools` can lock the kernel versions that should be
|
|
||||||
used. First specify the accepted versions in `pyproject.toml` and make
|
|
||||||
sure that `hf-kernels` is a build dependency:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[build-system]
|
|
||||||
requires = ["hf-kernels", "setuptools"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[tool.kernels.dependencies]
|
|
||||||
"kernels-community/activation" = ">=0.0.1"
|
|
||||||
```
|
|
||||||
|
|
||||||
Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
|
||||||
the locked revisions. The locked revision will be used when loading a kernel with
|
|
||||||
`get_locked_kernel`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from hf_kernels import get_locked_kernel
|
|
||||||
|
|
||||||
activation = get_locked_kernel("kernels-community/activation")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note:** the lock file is included in the package metadata, so it will only be visible
|
|
||||||
to `hf-kernels` after doing an (editable or regular) installation of your project.
|
|
||||||
|
|
||||||
## Pre-downloading locked kernels
|
|
||||||
|
|
||||||
Locked kernels can be pre-downloaded by running `hf-kernel download .` in your
|
|
||||||
project directory. This will download the kernels to your local Hugging Face
|
|
||||||
Hub cache.
|
|
||||||
|
|
||||||
The pre-downloaded kernels are used by the `get_locked_kernel` function.
|
|
||||||
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
|
|
||||||
want kernel loading to error when a kernel is not pre-downloaded, you can use
|
|
||||||
the `load_kernel` function instead:
|
|
||||||
|
|
||||||
````python
|
|
||||||
```python
|
|
||||||
from hf_kernels import load_kernel
|
|
||||||
|
|
||||||
activation = load_kernel("kernels-community/activation")
|
|
||||||
````
|
|
||||||
|
|||||||
@ -31,13 +31,13 @@ WORKDIR /app/kernel-test
|
|||||||
# install python depdencies
|
# install python depdencies
|
||||||
RUN uv add torch==2.5.0 numpy
|
RUN uv add torch==2.5.0 numpy
|
||||||
|
|
||||||
# copy hf-kernels lib
|
# copy kernels lib
|
||||||
COPY src ./hf-kernels/src
|
COPY src ./kernels/src
|
||||||
COPY pyproject.toml ./hf-kernels/pyproject.toml
|
COPY pyproject.toml ./kernels/pyproject.toml
|
||||||
COPY README.md ./hf-kernels/README.md
|
COPY README.md ./kernels/README.md
|
||||||
|
|
||||||
# install library
|
# install library
|
||||||
RUN uv pip install -e hf-kernels
|
RUN uv pip install -e kernels
|
||||||
|
|
||||||
# copy examples
|
# copy examples
|
||||||
COPY examples ./examples
|
COPY examples ./examples
|
||||||
@ -48,4 +48,4 @@ ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
|||||||
|
|
||||||
# command to run the script
|
# command to run the script
|
||||||
CMD ["uv", "run", "examples/basic.py"]
|
CMD ["uv", "run", "examples/basic.py"]
|
||||||
# CMD ["ls", "hf-kernels"]
|
# CMD ["ls", "kernels"]
|
||||||
|
|||||||
8
docs/docker.md
Normal file
8
docs/docker.md
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Using kernels in a Docker container
|
||||||
|
|
||||||
|
build and run the reference [examples/basic.py](examples/basic.py) in a Docker container with the following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||||
|
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||||
|
```
|
||||||
103
docs/kernel-requirements.md
Normal file
103
docs/kernel-requirements.md
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Kernel requirements
|
||||||
|
|
||||||
|
Kernels on the Hub must fulfill the requirements outlined on this page.
|
||||||
|
You can use [kernel-builder](https://github.com/huggingface/kernel-builder/)
|
||||||
|
to build conforming kernels.
|
||||||
|
|
||||||
|
## Directory layout
|
||||||
|
|
||||||
|
A kernel repository on the Hub must contain a `build` directory. This
|
||||||
|
directory contains build variants of a kernel in the form of directories
|
||||||
|
following the template
|
||||||
|
`<framework><version>-cxx<abiver>-<cu><cudaver>-<arch>-<os>`.
|
||||||
|
For example `build/torch26-cxx98-cu118-x86_64-linux`. The currently
|
||||||
|
recommended build variants are:
|
||||||
|
|
||||||
|
- `torch25-cxx11-cu118-x86_64-linux`
|
||||||
|
- `torch25-cxx11-cu121-x86_64-linux`
|
||||||
|
- `torch25-cxx11-cu124-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu118-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu121-x86_64-linux`
|
||||||
|
- `torch25-cxx98-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu118-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx11-cu126-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu118-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu124-x86_64-linux`
|
||||||
|
- `torch26-cxx98-cu126-x86_64-linux`
|
||||||
|
|
||||||
|
This list will be updated as new PyTorch versions are released. Kernels
|
||||||
|
that are in pure Python (e.g. Triton kernels) only need to provide a
|
||||||
|
single build variant:
|
||||||
|
|
||||||
|
- `torch-universal`
|
||||||
|
|
||||||
|
Each variant directory should contain a single directory with the same name
|
||||||
|
as the repository (replacing `-` by `_`). For instance, kernels in the
|
||||||
|
`kernels-community/activation` repository have a directories like
|
||||||
|
`build/<variant>/activation`. This directory
|
||||||
|
must be a Python package with an `__init__.py` file.
|
||||||
|
|
||||||
|
## Native Python module
|
||||||
|
|
||||||
|
Kernels will typically contain a native Python module with precompiled
|
||||||
|
compute kernels and bindings. This module must fulfill the following
|
||||||
|
requirements:
|
||||||
|
|
||||||
|
- Use [ABI3/Limited API](https://docs.python.org/3/c-api/stable.html#stable-application-binary-interface)
|
||||||
|
for compatibility with Python 3.9 and later.
|
||||||
|
- Compatible with glibc 2.27 or later. This means that no symbols
|
||||||
|
from later versions must be used. To archive this, the module should
|
||||||
|
be built against this glibc version. **Warning:** libgcc must also be
|
||||||
|
built against glibc 2.27 to avoid leaking symbols.
|
||||||
|
- No dynamic linkage against libstdc++/libc++. Linkage for C++ symbols
|
||||||
|
must be static.
|
||||||
|
- No dynamic library dependencies outside Torch or CUDA libraries
|
||||||
|
installed as dependencies of Torch.
|
||||||
|
|
||||||
|
(These requirements will be updated as new PyTorch versions are released.)
|
||||||
|
|
||||||
|
## Torch extension
|
||||||
|
|
||||||
|
Torch native extension functions must be [registered](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
|
||||||
|
in `torch.ops.<namespace>`. Since we allow loading of multiple versions of
|
||||||
|
a module in the same Python process, `namespace` must be unique for each
|
||||||
|
version of a kernel. Failing to do so will create clashes when different
|
||||||
|
versions of the same kernel are loaded. Two suggested ways of doing this
|
||||||
|
are:
|
||||||
|
|
||||||
|
- Appending a truncated SHA-1 hash of the git commit that the kernel was
|
||||||
|
built from to the name of the extension.
|
||||||
|
- Appending random material to the name of the extension.
|
||||||
|
|
||||||
|
**Note:** we recommend against appending a version number or git tag.
|
||||||
|
Version numbers are typically not bumped on each commit, so users
|
||||||
|
might use two different commits that happen to have the same version
|
||||||
|
number. Git tags are not stable, so they do not provide a good way
|
||||||
|
of guaranteeing uniqueness of the namespace.
|
||||||
|
|
||||||
|
## Python requirements
|
||||||
|
|
||||||
|
- Python code must be compatible with Python 3.9 and later.
|
||||||
|
- All Python code imports from the kernel itself must be relative. So,
|
||||||
|
for instance if in the example kernel `example`,
|
||||||
|
`module_b` needs a function from `module_a`, import as:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .module_a import foo
|
||||||
|
```
|
||||||
|
|
||||||
|
**Never use:**
|
||||||
|
|
||||||
|
```python
|
||||||
|
# DO NOT DO THIS!
|
||||||
|
|
||||||
|
from example.module_a import foo
|
||||||
|
```
|
||||||
|
|
||||||
|
The latter would import from the module `example` that is in Python's
|
||||||
|
global module dict. However, since we allow loading multiple versions
|
||||||
|
of a module, we uniquely name the module.
|
||||||
|
|
||||||
|
- Only modules from the Python standard library, Torch, or the kernel itself
|
||||||
|
can be imported.
|
||||||
44
docs/locking.md
Normal file
44
docs/locking.md
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Locking kernel versions
|
||||||
|
|
||||||
|
Projects that use `setuptools` can lock the kernel versions that should be
|
||||||
|
used. First specify the accepted versions in `pyproject.toml` and make
|
||||||
|
sure that `kernels` is a build dependency:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[build-system]
|
||||||
|
requires = ["kernels", "setuptools"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[tool.kernels.dependencies]
|
||||||
|
"kernels-community/activation" = ">=0.0.1"
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run `kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
||||||
|
the locked revisions. The locked revision will be used when loading a kernel with
|
||||||
|
`get_locked_kernel`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import get_locked_kernel
|
||||||
|
|
||||||
|
activation = get_locked_kernel("kernels-community/activation")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** the lock file is included in the package metadata, so it will only be visible
|
||||||
|
to `kernels` after doing an (editable or regular) installation of your project.
|
||||||
|
|
||||||
|
## Pre-downloading locked kernels
|
||||||
|
|
||||||
|
Locked kernels can be pre-downloaded by running `kernel download .` in your
|
||||||
|
project directory. This will download the kernels to your local Hugging Face
|
||||||
|
Hub cache.
|
||||||
|
|
||||||
|
The pre-downloaded kernels are used by the `get_locked_kernel` function.
|
||||||
|
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
|
||||||
|
want kernel loading to error when a kernel is not pre-downloaded, you can use
|
||||||
|
the `load_kernel` function instead:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from kernels import load_kernel
|
||||||
|
|
||||||
|
activation = load_kernel("kernels-community/activation")
|
||||||
|
```
|
||||||
@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from hf_kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
print("Starting examples/basic.py demo")
|
print("Starting examples/basic.py demo")
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "hf-kernels"
|
name = "kernels"
|
||||||
version = "0.1.5"
|
version = "0.2.0"
|
||||||
description = "Download cuda kernels"
|
description = "Download cuda kernels"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||||
@ -14,7 +14,7 @@ dependencies = [
|
|||||||
"huggingface-hub>=0.26.3",
|
"huggingface-hub>=0.26.3",
|
||||||
"packaging>=24.2",
|
"packaging>=24.2",
|
||||||
"tomli>=2.0.1; python_version<'3.11'",
|
"tomli>=2.0.1; python_version<'3.11'",
|
||||||
"torch>=2.4",
|
"torch>=2.5",
|
||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
@ -23,18 +23,42 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
"mypy == 1.14.1",
|
||||||
"pytest >=8",
|
"pytest >=8",
|
||||||
# Whatever version is compatible with pytest.
|
# Whatever version is compatible with pytest.
|
||||||
"pytest-benchmark",
|
"pytest-benchmark",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
hf-kernels = "hf_kernels.cli:main"
|
kernels = "kernels.cli:main"
|
||||||
|
|
||||||
[project.entry-points."egg_info.writers"]
|
[project.entry-points."egg_info.writers"]
|
||||||
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"
|
"kernels.lock" = "kernels.lockfile:write_egg_lockfile"
|
||||||
|
|
||||||
#[build-system]
|
|
||||||
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
|
[tool.ruff]
|
||||||
#build-backend = "hf_kernels.build"
|
exclude = [
|
||||||
#backend-path = ["src"]
|
".eggs",
|
||||||
|
".git",
|
||||||
|
".git-rewrite",
|
||||||
|
".hg",
|
||||||
|
".mypy_cache",
|
||||||
|
".nox",
|
||||||
|
".pants.d",
|
||||||
|
".pytype",
|
||||||
|
".ruff_cache",
|
||||||
|
".svn",
|
||||||
|
".tox",
|
||||||
|
".venv",
|
||||||
|
".venv*",
|
||||||
|
"__pypackages__",
|
||||||
|
"_build",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"venv",
|
||||||
|
]
|
||||||
|
line-length = 119
|
||||||
|
# Ignored rules:
|
||||||
|
# "E501" -> line length violation
|
||||||
|
lint.ignore = ["E501"]
|
||||||
|
lint.select = ["E", "F", "I", "W"]
|
||||||
|
|||||||
@ -1,3 +0,0 @@
|
|||||||
from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel
|
|
||||||
|
|
||||||
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
|
|
||||||
@ -1,144 +0,0 @@
|
|||||||
"""
|
|
||||||
Python shims for the PEP 517 and PEP 660 build backend.
|
|
||||||
|
|
||||||
Major imports in this module are required to be lazy:
|
|
||||||
```
|
|
||||||
$ hyperfine \
|
|
||||||
"/usr/bin/python3 -c \"print('hi')\"" \
|
|
||||||
"/usr/bin/python3 -c \"from subprocess import check_call; print('hi')\""
|
|
||||||
Base: Time (mean ± σ): 11.0 ms ± 1.7 ms [User: 8.5 ms, System: 2.5 ms]
|
|
||||||
With import: Time (mean ± σ): 15.2 ms ± 2.0 ms [User: 12.3 ms, System: 2.9 ms]
|
|
||||||
Base 1.38 ± 0.28 times faster than with import
|
|
||||||
```
|
|
||||||
|
|
||||||
The same thing goes for the typing module, so we use Python 3.10 type annotations that
|
|
||||||
don't require importing typing but then quote them so earlier Python version ignore
|
|
||||||
them while IDEs and type checker can see through the quotes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
|
||||||
|
|
||||||
TYPE_CHECKING = False
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Mapping, Sequence # noqa:I001
|
|
||||||
from typing import Any # noqa:I001
|
|
||||||
|
|
||||||
|
|
||||||
def warn_config_settings(config_settings: "Mapping[Any, Any] | None" = None) -> None:
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if config_settings:
|
|
||||||
print("Warning: Config settings are not supported", file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
def call(
|
|
||||||
args: "Sequence[str]", config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""Invoke a uv subprocess and return the filename from stdout."""
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
# Unlike `find_uv_bin`, this mechanism must work according to PEP 517
|
|
||||||
import os
|
|
||||||
|
|
||||||
cwd = os.getcwd()
|
|
||||||
filename = os.path.join(cwd, "pyproject.toml")
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
data = tomllib.load(f)
|
|
||||||
|
|
||||||
for kernel, _ in (
|
|
||||||
data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items()
|
|
||||||
):
|
|
||||||
from hf_kernels.utils import install_kernel
|
|
||||||
|
|
||||||
install_kernel(kernel, revision="main")
|
|
||||||
uv_bin = shutil.which("uv")
|
|
||||||
if uv_bin is None:
|
|
||||||
raise RuntimeError("uv was not properly installed")
|
|
||||||
# Forward stderr, capture stdout for the filename
|
|
||||||
result = subprocess.run([uv_bin, *args], stdout=subprocess.PIPE)
|
|
||||||
if result.returncode != 0:
|
|
||||||
sys.exit(result.returncode)
|
|
||||||
# If there was extra stdout, forward it (there should not be extra stdout)
|
|
||||||
stdout = result.stdout.decode("utf-8").strip().splitlines(keepends=True)
|
|
||||||
sys.stdout.writelines(stdout[:-1])
|
|
||||||
# Fail explicitly instead of an irrelevant stacktrace
|
|
||||||
if not stdout:
|
|
||||||
print("uv subprocess did not return a filename on stdout", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
return stdout[-1].strip()
|
|
||||||
|
|
||||||
|
|
||||||
def build_sdist(
|
|
||||||
sdist_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `build_sdist`."""
|
|
||||||
args = ["build-backend", "build-sdist", sdist_directory]
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def build_wheel(
|
|
||||||
wheel_directory: str,
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
metadata_directory: "str | None" = None,
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `build_wheel`."""
|
|
||||||
args = ["build-backend", "build-wheel", wheel_directory]
|
|
||||||
if metadata_directory:
|
|
||||||
args.extend(["--metadata-directory", metadata_directory])
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_sdist(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 517 hook `get_requires_for_build_sdist`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_wheel(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 517 hook `get_requires_for_build_wheel`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_metadata_for_build_wheel(
|
|
||||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 517 hook `prepare_metadata_for_build_wheel`."""
|
|
||||||
args = ["build-backend", "prepare-metadata-for-build-wheel", metadata_directory]
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def build_editable(
|
|
||||||
wheel_directory: str,
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
metadata_directory: "str | None" = None,
|
|
||||||
) -> str:
|
|
||||||
"""PEP 660 hook `build_editable`."""
|
|
||||||
args = ["build-backend", "build-editable", wheel_directory]
|
|
||||||
|
|
||||||
if metadata_directory:
|
|
||||||
args.extend(["--metadata-directory", metadata_directory])
|
|
||||||
return call(args, config_settings)
|
|
||||||
|
|
||||||
|
|
||||||
def get_requires_for_build_editable(
|
|
||||||
config_settings: "Mapping[Any, Any] | None" = None,
|
|
||||||
) -> "Sequence[str]":
|
|
||||||
"""PEP 660 hook `get_requires_for_build_editable`."""
|
|
||||||
warn_config_settings(config_settings)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_metadata_for_build_editable(
|
|
||||||
metadata_directory: str, config_settings: "Mapping[Any, Any] | None" = None
|
|
||||||
) -> str:
|
|
||||||
"""PEP 660 hook `prepare_metadata_for_build_editable`."""
|
|
||||||
args = ["build-backend", "prepare-metadata-for-build-editable", metadata_directory]
|
|
||||||
return call(args, config_settings)
|
|
||||||
@ -1,163 +0,0 @@
|
|||||||
import ctypes
|
|
||||||
import importlib
|
|
||||||
import importlib.metadata
|
|
||||||
import inspect
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import sys
|
|
||||||
from importlib.metadata import Distribution
|
|
||||||
from types import ModuleType
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
|
||||||
from packaging.version import parse
|
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
|
||||||
from hf_kernels.lockfile import KernelLock
|
|
||||||
|
|
||||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
|
||||||
|
|
||||||
|
|
||||||
def build_variant():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
torch_version = parse(torch.__version__)
|
|
||||||
cuda_version = parse(torch.version.cuda)
|
|
||||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
|
||||||
cpu = platform.machine()
|
|
||||||
os = platform.system().lower()
|
|
||||||
|
|
||||||
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
|
||||||
|
|
||||||
|
|
||||||
def import_from_path(module_name: str, file_path):
|
|
||||||
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
|
||||||
# it would also be used for other imports. So, we make a module name that
|
|
||||||
# depends on the path for it to be unique using the hex-encoded hash of
|
|
||||||
# the path.
|
|
||||||
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
|
||||||
module_name = f"{module_name}_{path_hash}"
|
|
||||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
sys.modules[module_name] = module
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
|
|
||||||
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
|
|
||||||
"torch"
|
|
||||||
]["name"]
|
|
||||||
repo_path = snapshot_download(
|
|
||||||
repo_id,
|
|
||||||
allow_patterns=f"build/{build_variant()}/*",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
return package_name, f"{repo_path}/build/{build_variant()}"
|
|
||||||
|
|
||||||
|
|
||||||
def install_kernel_all_variants(
|
|
||||||
repo_id: str, revision: str, local_files_only: bool = False
|
|
||||||
):
|
|
||||||
snapshot_download(
|
|
||||||
repo_id,
|
|
||||||
allow_patterns="build/*",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
|
|
||||||
with open(
|
|
||||||
hf_hub_download(
|
|
||||||
repo_id,
|
|
||||||
"build.toml",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
revision=revision,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
),
|
|
||||||
"rb",
|
|
||||||
) as f:
|
|
||||||
return tomllib.load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def get_kernel(repo_id: str, revision: str = "main"):
|
|
||||||
package_name, package_path = install_kernel(repo_id, revision=revision)
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def load_kernel(repo_id: str):
|
|
||||||
"""Get a pre-downloaded, locked kernel."""
|
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
|
||||||
|
|
||||||
if locked_sha is None:
|
|
||||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
|
||||||
|
|
||||||
filename = hf_hub_download(
|
|
||||||
repo_id,
|
|
||||||
"build.toml",
|
|
||||||
cache_dir=CACHE_DIR,
|
|
||||||
local_files_only=True,
|
|
||||||
revision=locked_sha,
|
|
||||||
)
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
metadata = tomllib.load(f)
|
|
||||||
package_name = metadata["torch"]["name"]
|
|
||||||
|
|
||||||
repo_path = os.path.dirname(filename)
|
|
||||||
package_path = f"{repo_path}/build/{build_variant()}"
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False):
|
|
||||||
"""Get a kernel using a lock file."""
|
|
||||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
|
||||||
|
|
||||||
if locked_sha is None:
|
|
||||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
|
||||||
|
|
||||||
package_name, package_path = install_kernel(
|
|
||||||
repo_id, locked_sha, local_files_only=local_files_only
|
|
||||||
)
|
|
||||||
|
|
||||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
|
||||||
for dist in _get_caller_distributions():
|
|
||||||
lock_json = dist.read_text("hf-kernels.lock")
|
|
||||||
if lock_json is not None:
|
|
||||||
for kernel_lock_json in json.loads(lock_json):
|
|
||||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
|
||||||
if kernel_lock.repo_id == repo_id:
|
|
||||||
return kernel_lock.sha
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_distributions() -> List[Distribution]:
|
|
||||||
module = _get_caller_module()
|
|
||||||
if module is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Look up all possible distributions that this module could be from.
|
|
||||||
package = module.__name__.split(".")[0]
|
|
||||||
dist_names = importlib.metadata.packages_distributions().get(package)
|
|
||||||
if dist_names is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_caller_module() -> Optional[ModuleType]:
|
|
||||||
stack = inspect.stack()
|
|
||||||
# Get first module in the stack that is not the current module.
|
|
||||||
first_module = inspect.getmodule(stack[0][0])
|
|
||||||
for frame in stack[1:]:
|
|
||||||
module = inspect.getmodule(frame[0])
|
|
||||||
if module is not None and module != first_module:
|
|
||||||
return module
|
|
||||||
return first_module
|
|
||||||
3
src/kernels/__init__.py
Normal file
3
src/kernels/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from kernels.utils import get_kernel, get_locked_kernel, install_kernel, load_kernel
|
||||||
|
|
||||||
|
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
|
||||||
@ -4,14 +4,14 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
from kernels.compat import tomllib
|
||||||
from hf_kernels.lockfile import KernelLock, get_kernel_locks
|
from kernels.lockfile import KernelLock, get_kernel_locks
|
||||||
from hf_kernels.utils import install_kernel, install_kernel_all_variants
|
from kernels.utils import install_kernel, install_kernel_all_variants
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="hf-kernel", description="Manage compute kernels"
|
prog="kernel", description="Manage compute kernels"
|
||||||
)
|
)
|
||||||
subparsers = parser.add_subparsers(required=True)
|
subparsers = parser.add_subparsers(required=True)
|
||||||
|
|
||||||
@ -41,15 +41,17 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
def download_kernels(args):
|
def download_kernels(args):
|
||||||
lock_path = args.project_dir / "hf-kernels.lock"
|
lock_path = args.project_dir / "kernels.lock"
|
||||||
|
|
||||||
if not lock_path.exists():
|
if not lock_path.exists():
|
||||||
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
|
print(f"No kernels.lock file found in: {args.project_dir}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
with open(args.project_dir / "hf-kernels.lock", "r") as f:
|
with open(args.project_dir / "kernels.lock", "r") as f:
|
||||||
lock_json = json.load(f)
|
lock_json = json.load(f)
|
||||||
|
|
||||||
|
all_successful = True
|
||||||
|
|
||||||
for kernel_lock_json in lock_json:
|
for kernel_lock_json in lock_json:
|
||||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||||
print(
|
print(
|
||||||
@ -57,9 +59,22 @@ def download_kernels(args):
|
|||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
if args.all_variants:
|
if args.all_variants:
|
||||||
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
|
install_kernel_all_variants(
|
||||||
|
kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
|
try:
|
||||||
|
install_kernel(
|
||||||
|
kernel_lock.repo_id,
|
||||||
|
kernel_lock.sha,
|
||||||
|
variant_locks=kernel_lock.variants,
|
||||||
|
)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
print(e, file=sys.stderr)
|
||||||
|
all_successful = False
|
||||||
|
|
||||||
|
if not all_successful:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def lock_kernels(args):
|
def lock_kernels(args):
|
||||||
@ -72,7 +87,7 @@ def lock_kernels(args):
|
|||||||
for kernel, version in kernel_versions.items():
|
for kernel, version in kernel_versions.items():
|
||||||
all_locks.append(get_kernel_locks(kernel, version))
|
all_locks.append(get_kernel_locks(kernel, version))
|
||||||
|
|
||||||
with open(args.project_dir / "hf-kernels.lock", "w") as f:
|
with open(args.project_dir / "kernels.lock", "w") as f:
|
||||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||||
|
|
||||||
|
|
||||||
@ -1,33 +1,37 @@
|
|||||||
|
import hashlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.hf_api import GitRefInfo
|
||||||
from packaging.specifiers import SpecifierSet
|
from packaging.specifiers import SpecifierSet
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
from hf_kernels.compat import tomllib
|
from kernels.compat import tomllib
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileLock:
|
class VariantLock:
|
||||||
filename: str
|
hash: str
|
||||||
blob_id: str
|
hash_type: str = "git_lfs_concat"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KernelLock:
|
class KernelLock:
|
||||||
repo_id: str
|
repo_id: str
|
||||||
sha: str
|
sha: str
|
||||||
files: List[FileLock]
|
variants: Dict[str, VariantLock]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, o: Dict):
|
def from_json(cls, o: Dict):
|
||||||
files = [FileLock(**f) for f in o["files"]]
|
variants = {
|
||||||
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
|
variant: VariantLock(**lock) for variant, lock in o["variants"].items()
|
||||||
|
}
|
||||||
|
return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants)
|
||||||
|
|
||||||
|
|
||||||
def _get_available_versions(repo_id: str):
|
def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]:
|
||||||
"""Get kernel versions that are available in the repository."""
|
"""Get kernel versions that are available in the repository."""
|
||||||
versions = {}
|
versions = {}
|
||||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||||
@ -41,7 +45,7 @@ def _get_available_versions(repo_id: str):
|
|||||||
return versions
|
return versions
|
||||||
|
|
||||||
|
|
||||||
def get_kernel_locks(repo_id: str, version_spec: str):
|
def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock:
|
||||||
"""
|
"""
|
||||||
Get the locks for a kernel with the given version spec.
|
Get the locks for a kernel with the given version spec.
|
||||||
|
|
||||||
@ -72,31 +76,55 @@ def get_kernel_locks(repo_id: str, version_spec: str):
|
|||||||
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
file_locks = []
|
variant_files: Dict[str, List[Tuple[bytes, str]]] = {}
|
||||||
for sibling in r.siblings:
|
for sibling in r.siblings:
|
||||||
if sibling.rfilename.startswith("build/torch"):
|
if sibling.rfilename.startswith("build/torch"):
|
||||||
if sibling.blob_id is None:
|
if sibling.blob_id is None:
|
||||||
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
|
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
|
||||||
|
|
||||||
file_locks.append(
|
path = Path(sibling.rfilename)
|
||||||
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
|
variant = path.parts[1]
|
||||||
)
|
filename = Path(*path.parts[2:])
|
||||||
|
|
||||||
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
|
hash = sibling.lfs.sha256 if sibling.lfs is not None else sibling.blob_id
|
||||||
|
|
||||||
|
files = variant_files.setdefault(variant, [])
|
||||||
|
|
||||||
|
# Encode as posix for consistent slash handling, then encode
|
||||||
|
# as utf-8 for byte-wise sorting later.
|
||||||
|
files.append((filename.as_posix().encode("utf-8"), hash))
|
||||||
|
|
||||||
|
variant_locks = {}
|
||||||
|
for variant, files in variant_files.items():
|
||||||
|
m = hashlib.sha256()
|
||||||
|
for filename_bytes, hash in sorted(files):
|
||||||
|
# Filename as bytes.
|
||||||
|
m.update(filename_bytes)
|
||||||
|
# Git blob or LFS file hash as bytes.
|
||||||
|
m.update(bytes.fromhex(hash))
|
||||||
|
|
||||||
|
variant_locks[variant] = VariantLock(hash=f"sha256-{m.hexdigest()}")
|
||||||
|
|
||||||
|
return KernelLock(repo_id=repo_id, sha=r.sha, variants=variant_locks)
|
||||||
|
|
||||||
|
|
||||||
def write_egg_lockfile(cmd, basename, filename):
|
def write_egg_lockfile(cmd, basename, filename):
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
cwd = Path.cwd()
|
cwd = Path.cwd()
|
||||||
with open(cwd / "pyproject.toml", "rb") as f:
|
pyproject_path = cwd / "pyproject.toml"
|
||||||
|
if not pyproject_path.exists():
|
||||||
|
# Nothing to do if the project doesn't have pyproject.toml.
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(pyproject_path, "rb") as f:
|
||||||
data = tomllib.load(f)
|
data = tomllib.load(f)
|
||||||
|
|
||||||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
||||||
if kernel_versions is None:
|
if kernel_versions is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
lock_path = cwd / "hf-kernels.lock"
|
lock_path = cwd / "kernels.lock"
|
||||||
if not lock_path.exists():
|
if not lock_path.exists():
|
||||||
logging.warning(f"Lock file {lock_path} does not exist")
|
logging.warning(f"Lock file {lock_path} does not exist")
|
||||||
# Ensure that the file gets deleted in editable installs.
|
# Ensure that the file gets deleted in editable installs.
|
||||||
290
src/kernels/utils.py
Normal file
290
src/kernels/utils.py
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
import ctypes
|
||||||
|
import hashlib
|
||||||
|
import importlib
|
||||||
|
import importlib.metadata
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
from importlib.metadata import Distribution
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from packaging.version import parse
|
||||||
|
|
||||||
|
from kernels.lockfile import KernelLock, VariantLock
|
||||||
|
|
||||||
|
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||||
|
|
||||||
|
|
||||||
|
def build_variant() -> str:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.version.cuda is None:
|
||||||
|
raise AssertionError(
|
||||||
|
"This kernel requires CUDA to be installed. Torch was not compiled with CUDA enabled."
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_version = parse(torch.__version__)
|
||||||
|
cuda_version = parse(torch.version.cuda)
|
||||||
|
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||||
|
cpu = platform.machine()
|
||||||
|
os = platform.system().lower()
|
||||||
|
|
||||||
|
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-cu{cuda_version.major}{cuda_version.minor}-{cpu}-{os}"
|
||||||
|
|
||||||
|
|
||||||
|
def universal_build_variant() -> str:
|
||||||
|
# Once we support other frameworks, detection goes here.
|
||||||
|
return "torch-universal"
|
||||||
|
|
||||||
|
|
||||||
|
def import_from_path(module_name: str, file_path: Path) -> ModuleType:
|
||||||
|
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
||||||
|
# it would also be used for other imports. So, we make a module name that
|
||||||
|
# depends on the path for it to be unique using the hex-encoded hash of
|
||||||
|
# the path.
|
||||||
|
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
||||||
|
module_name = f"{module_name}_{path_hash}"
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
if module is None:
|
||||||
|
raise ImportError(f"Cannot load module {module_name} from spec")
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module) # type: ignore
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def install_kernel(
|
||||||
|
repo_id: str,
|
||||||
|
revision: str,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||||
|
) -> Tuple[str, Path]:
|
||||||
|
"""
|
||||||
|
Download a kernel for the current environment to the cache.
|
||||||
|
|
||||||
|
The output path is validated againt `hash` when set.
|
||||||
|
"""
|
||||||
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
universal_variant_path = repo_path / "build" / universal_variant
|
||||||
|
|
||||||
|
if not variant_path.exists() and universal_variant_path.exists():
|
||||||
|
# Fall back to universal variant.
|
||||||
|
variant = universal_variant
|
||||||
|
variant_path = universal_variant_path
|
||||||
|
|
||||||
|
if variant_locks is not None:
|
||||||
|
variant_lock = variant_locks.get(variant)
|
||||||
|
if variant_lock is None:
|
||||||
|
raise ValueError(f"No lock found for build variant: {variant}")
|
||||||
|
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
|
||||||
|
|
||||||
|
module_init_path = variant_path / package_name / "__init__.py"
|
||||||
|
|
||||||
|
if not os.path.exists(module_init_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Kernel `{repo_id}` at revision {revision} does not have build: {variant}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return package_name, variant_path
|
||||||
|
|
||||||
|
|
||||||
|
def install_kernel_all_variants(
|
||||||
|
repo_id: str,
|
||||||
|
revision: str,
|
||||||
|
local_files_only: bool = False,
|
||||||
|
variant_locks: Optional[Dict[str, VariantLock]] = None,
|
||||||
|
) -> Path:
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns="build/*",
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
revision=revision,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if variant_locks is not None:
|
||||||
|
for entry in (repo_path / "build").iterdir():
|
||||||
|
variant = entry.parts[-1]
|
||||||
|
|
||||||
|
variant_lock = variant_locks.get(variant)
|
||||||
|
if variant_lock is None:
|
||||||
|
raise ValueError(f"No lock found for build variant: {variant}")
|
||||||
|
|
||||||
|
validate_kernel(
|
||||||
|
repo_path=repo_path, variant=variant, hash=variant_lock.hash
|
||||||
|
)
|
||||||
|
|
||||||
|
return repo_path / "build"
|
||||||
|
|
||||||
|
|
||||||
|
def get_kernel(repo_id: str, revision: str = "main") -> ModuleType:
|
||||||
|
package_name, package_path = install_kernel(repo_id, revision=revision)
|
||||||
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def load_kernel(repo_id: str) -> ModuleType:
|
||||||
|
"""Get a pre-downloaded, locked kernel."""
|
||||||
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
|
|
||||||
|
if locked_sha is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Kernel `{repo_id}` is not locked. Please lock it with `kernels lock <project>` and then reinstall the project."
|
||||||
|
)
|
||||||
|
|
||||||
|
package_name = package_name_from_repo_id(repo_id)
|
||||||
|
|
||||||
|
variant = build_variant()
|
||||||
|
universal_variant = universal_build_variant()
|
||||||
|
|
||||||
|
repo_path = Path(
|
||||||
|
snapshot_download(
|
||||||
|
repo_id,
|
||||||
|
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
|
||||||
|
cache_dir=CACHE_DIR,
|
||||||
|
local_files_only=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
universal_variant_path = repo_path / "build" / universal_variant
|
||||||
|
if not variant_path.exists() and universal_variant_path.exists():
|
||||||
|
# Fall back to universal variant.
|
||||||
|
variant = universal_variant
|
||||||
|
variant_path = universal_variant_path
|
||||||
|
|
||||||
|
module_init_path = variant_path / package_name / "__init__.py"
|
||||||
|
if not os.path.exists(module_init_path):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
|
||||||
|
)
|
||||||
|
|
||||||
|
return import_from_path(package_name, variant_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
|
||||||
|
"""Get a kernel using a lock file."""
|
||||||
|
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||||
|
|
||||||
|
if locked_sha is None:
|
||||||
|
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
||||||
|
|
||||||
|
package_name, package_path = install_kernel(
|
||||||
|
repo_id, locked_sha, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
|
||||||
|
return import_from_path(package_name, package_path / package_name / "__init__.py")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
||||||
|
for dist in _get_caller_distributions():
|
||||||
|
lock_json = dist.read_text("kernels.lock")
|
||||||
|
if lock_json is not None:
|
||||||
|
for kernel_lock_json in json.loads(lock_json):
|
||||||
|
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||||
|
if kernel_lock.repo_id == repo_id:
|
||||||
|
return kernel_lock.sha
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_distributions() -> List[Distribution]:
|
||||||
|
module = _get_caller_module()
|
||||||
|
if module is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Look up all possible distributions that this module could be from.
|
||||||
|
package = module.__name__.split(".")[0]
|
||||||
|
dist_names = importlib.metadata.packages_distributions().get(package)
|
||||||
|
if dist_names is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_caller_module() -> Optional[ModuleType]:
|
||||||
|
stack = inspect.stack()
|
||||||
|
# Get first module in the stack that is not the current module.
|
||||||
|
first_module = inspect.getmodule(stack[0][0])
|
||||||
|
for frame in stack[1:]:
|
||||||
|
module = inspect.getmodule(frame[0])
|
||||||
|
if module is not None and module != first_module:
|
||||||
|
return module
|
||||||
|
return first_module
|
||||||
|
|
||||||
|
|
||||||
|
def validate_kernel(*, repo_path: Path, variant: str, hash: str):
|
||||||
|
"""Validate the given build variant of a kernel against a hasht."""
|
||||||
|
variant_path = repo_path / "build" / variant
|
||||||
|
|
||||||
|
# Get the file paths. The first element is a byte-encoded relative path
|
||||||
|
# used for sorting. The second element is the absolute path.
|
||||||
|
files: List[Tuple[bytes, Path]] = []
|
||||||
|
# Ideally we'd use Path.walk, but it's only available in Python 3.12.
|
||||||
|
for dirpath, _, filenames in os.walk(variant_path):
|
||||||
|
for filename in filenames:
|
||||||
|
file_abs = Path(dirpath) / filename
|
||||||
|
|
||||||
|
# Python likes to create files when importing modules from the
|
||||||
|
# cache, only hash files that are symlinked blobs.
|
||||||
|
if file_abs.is_symlink():
|
||||||
|
files.append(
|
||||||
|
(
|
||||||
|
file_abs.relative_to(variant_path).as_posix().encode("utf-8"),
|
||||||
|
file_abs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
for filename_bytes, full_path in sorted(files):
|
||||||
|
m.update(filename_bytes)
|
||||||
|
|
||||||
|
blob_filename = full_path.resolve().name
|
||||||
|
if len(blob_filename) == 40:
|
||||||
|
# SHA-1 hashed, so a Git blob.
|
||||||
|
m.update(git_hash_object(full_path.read_bytes()))
|
||||||
|
elif len(blob_filename) == 64:
|
||||||
|
# SHA-256 hashed, so a Git LFS blob.
|
||||||
|
m.update(hashlib.sha256(full_path.read_bytes()).digest())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected blob filename length: {len(blob_filename)}")
|
||||||
|
|
||||||
|
computedHash = f"sha256-{m.hexdigest()}"
|
||||||
|
if computedHash != hash:
|
||||||
|
raise ValueError(
|
||||||
|
f"Lock file specifies kernel with hash {hash}, but downloaded kernel has hash: {computedHash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def git_hash_object(data: bytes, object_type: str = "blob"):
|
||||||
|
"""Calculate git SHA1 of data."""
|
||||||
|
header = f"{object_type} {len(data)}\0".encode()
|
||||||
|
m = hashlib.sha1()
|
||||||
|
m.update(header)
|
||||||
|
m.update(data)
|
||||||
|
return m.digest()
|
||||||
|
|
||||||
|
|
||||||
|
def package_name_from_repo_id(repo_id: str) -> str:
|
||||||
|
return repo_id.split("/")[-1].replace("-", "_")
|
||||||
66
tests/hash_validation/kernels.lock
Normal file
66
tests/hash_validation/kernels.lock
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"repo_id": "kernels-community/activation",
|
||||||
|
"sha": "6a030420d0dd33ffdc1281afc8ae8e94b4f4f9d0",
|
||||||
|
"variants": {
|
||||||
|
"torch25-cxx11-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-3e39de10721a6b21806834fc95c96526b9cfe2c2052829184f2d3fa48ef5849d",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx11-cu121-x86_64-linux": {
|
||||||
|
"hash": "sha256-b0dee22c65bb277fa8150f9ea3fc90e2b1c11f84b5d760bbf4ab9c7a4b102e58",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx11-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-8960cf857d641d591a7c2d4264925cc2bf7b4a6f9d738b74082b2fb0806db19a",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-0496e04c2900a2dc7ab0f3b95fe8ce9da69faab6b5ca3f55ddd62c26c81268d0",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu121-x86_64-linux": {
|
||||||
|
"hash": "sha256-172b793b24dfed3dcb9adc7d3487f260c05b310c598fc6ee8abb3e230c59a0a8",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch25-cxx98-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-12f5e66f32dc4cf4b21f43f76efad198556024da67a1ce28e88ea2d49ad8bdcc",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-bb70e2f36f0b4d12868956c2ad713c756570ff0e0eb4cf7fc3a78ebde617975b",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-a745732eb9ec5d6a54565dbeec5b3c983cc6aa072a4a2576ab2fef9b2a600005",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx11-cu126-x86_64-linux": {
|
||||||
|
"hash": "sha256-1160684ca09c065864f27c5c110281807a1ec31d603bf05fcb974e9e7cfe35cc",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu118-x86_64-linux": {
|
||||||
|
"hash": "sha256-24459d068943b93e4d55e94811469bf7e850d7958785132b108f1240724b846f",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu124-x86_64-linux": {
|
||||||
|
"hash": "sha256-5b009ba63ab6d52ac1aaf70057a2d0fa6ea5d1788a2416111be02103c6bcaaaf",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
},
|
||||||
|
"torch26-cxx98-cu126-x86_64-linux": {
|
||||||
|
"hash": "sha256-05128889b4bdaf9ef58f3c07d93218deaa08e06f9121931b47efef8826482e4a",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"repo_id": "kernels-community/triton-scaled-mm",
|
||||||
|
"sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f",
|
||||||
|
"variants": {
|
||||||
|
"torch-universal": {
|
||||||
|
"hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52",
|
||||||
|
"hash_type": "git_lfs_concat"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
3
tests/hash_validation/pyproject.toml
Normal file
3
tests/hash_validation/pyproject.toml
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[tool.kernels.dependencies]
|
||||||
|
"kernels-community/activation" = ">=0.0.2"
|
||||||
|
"kernels-community/triton-scaled-mm" = ">=0.0.2"
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from hf_kernels import get_kernel
|
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -8,6 +9,11 @@ def kernel():
|
|||||||
return get_kernel("kernels-community/activation")
|
return get_kernel("kernels-community/activation")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def universal_kernel():
|
||||||
|
return get_kernel("kernels-community/triton-scaled-mm")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def device():
|
def device():
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
@ -28,3 +34,17 @@ def test_gelu_fast(kernel, device):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert torch.allclose(y, expected)
|
assert torch.allclose(y, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_universal_kernel(universal_kernel):
|
||||||
|
torch.manual_seed(0)
|
||||||
|
A = torch.randint(-10, 10, (64, 128), dtype=torch.int8, device="cuda")
|
||||||
|
B = torch.randint(-10, 10, (128, 96), dtype=torch.int8, device="cuda")
|
||||||
|
scale_a = torch.tensor(0.4, dtype=torch.float16, device="cuda")
|
||||||
|
scale_b = torch.tensor(0.6, dtype=torch.float16, device="cuda")
|
||||||
|
|
||||||
|
out = universal_kernel.triton_scaled_mm(A, B, scale_a, scale_b, torch.float16)
|
||||||
|
out_check = (A * scale_a) @ (B * scale_b)
|
||||||
|
out_check = out_check.to(torch.float16)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from hf_kernels import get_kernel
|
|
||||||
|
from kernels import get_kernel
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
21
tests/test_hash_validation.py
Normal file
21
tests/test_hash_validation.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from kernels.cli import download_kernels
|
||||||
|
|
||||||
|
|
||||||
|
# Mock download arguments class.
|
||||||
|
@dataclass
|
||||||
|
class DownloadArgs:
|
||||||
|
all_variants: bool
|
||||||
|
project_dir: Path
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_hash_validation():
|
||||||
|
project_dir = Path(__file__).parent / "hash_validation"
|
||||||
|
download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_all_hash_validation():
|
||||||
|
project_dir = Path(__file__).parent / "hash_validation"
|
||||||
|
download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir))
|
||||||
Reference in New Issue
Block a user