mirror of
https://github.com/huggingface/kernels.git
synced 2025-10-21 05:30:30 +08:00
Compare commits
43 Commits
build_syst
...
v0.1.6
Author | SHA1 | Date | |
---|---|---|---|
3212affd9e | |||
7ff40a859c | |||
cf64113c8b | |||
ba4f88f5aa | |||
d61971ad46 | |||
d7f3831992 | |||
03875be8a0 | |||
e41ef2358e | |||
aca3ce7dfb | |||
3bae6fca7d | |||
cffbafa61f | |||
29b27a58cf | |||
bee46be22b | |||
e05ba73534 | |||
544354cb97 | |||
105704b910 | |||
ea518db3d9 | |||
b88f3b107f | |||
60864349af | |||
9a04c2fa91 | |||
b6ae897c4d | |||
c9d6ba261a | |||
ef362cbbd0 | |||
c336be09bb | |||
e476ca406c | |||
4723d7914e | |||
2706669b75 | |||
10e4692a6b | |||
2cde348805 | |||
4da2e1a3dd | |||
b24ccdcf67 | |||
d6e807c081 | |||
b74005dd70 | |||
1619b2523d | |||
63cbbf71dc | |||
af55097d46 | |||
8747b3fbe2 | |||
c5ad392b77 | |||
cbe41bc9ec | |||
433fcc5268 | |||
7f75050a8a | |||
14b9350f3c | |||
8ef2f2fb5b |
45
.github/workflows/test.yml
vendored
Normal file
45
.github/workflows/test.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Test hf-kernels
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
types: [opened, synchronize, reopened] # trigger on PRs
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Run tests
|
||||
runs-on:
|
||||
group: aws-g6-24xlarge
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
max-parallel: 4
|
||||
matrix:
|
||||
python-version: ["3.10", "3.12"]
|
||||
torch-version: ["2.5.1", "2.6.0"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv and set the python version
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Lock Torch version
|
||||
run: uv lock --upgrade-package "torch==${{ matrix.torch-version }}"
|
||||
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras --dev
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests
|
59
README.md
59
README.md
@ -1,11 +1,11 @@
|
||||
# kernels
|
||||
# hf-kernels
|
||||
|
||||
Make sure you have `torch==2.5.1+cu124` installed.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel
|
||||
from hf_kernels import get_kernel
|
||||
|
||||
# Download optimized kernels from the Hugging Face hub
|
||||
activation = get_kernel("kernels-community/activation")
|
||||
@ -19,3 +19,58 @@ activation.gelu_fast(y, x)
|
||||
|
||||
print(y)
|
||||
```
|
||||
|
||||
## Docker Reference
|
||||
|
||||
build and run the reference [example/basic.py](example/basic.py) in a Docker container with the following commands:
|
||||
|
||||
```bash
|
||||
docker build --platform linux/amd64 -t kernels-reference -f docker/Dockerfile.reference .
|
||||
docker run --gpus all -it --rm -e HF_TOKEN=$HF_TOKEN kernels-reference
|
||||
```
|
||||
|
||||
## Locking kernel versions
|
||||
|
||||
Projects that use `setuptools` can lock the kernel versions that should be
|
||||
used. First specify the accepted versions in `pyproject.toml` and make
|
||||
sure that `hf-kernels` is a build dependency:
|
||||
|
||||
```toml
|
||||
[build-system]
|
||||
requires = ["hf-kernels", "setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.kernels.dependencies]
|
||||
"kernels-community/activation" = ">=0.0.1"
|
||||
```
|
||||
|
||||
Then run `hf-kernel lock .` in the project directory. This generates a `kernels.lock` file with
|
||||
the locked revisions. The locked revision will be used when loading a kernel with
|
||||
`get_locked_kernel`:
|
||||
|
||||
```python
|
||||
from hf_kernels import get_locked_kernel
|
||||
|
||||
activation = get_locked_kernel("kernels-community/activation")
|
||||
```
|
||||
|
||||
**Note:** the lock file is included in the package metadata, so it will only be visible
|
||||
to `hf-kernels` after doing an (editable or regular) installation of your project.
|
||||
|
||||
## Pre-downloading locked kernels
|
||||
|
||||
Locked kernels can be pre-downloaded by running `hf-kernel download .` in your
|
||||
project directory. This will download the kernels to your local Hugging Face
|
||||
Hub cache.
|
||||
|
||||
The pre-downloaded kernels are used by the `get_locked_kernel` function.
|
||||
`get_locked_kernel` will download a kernel when it is not pre-downloaded. If you
|
||||
want kernel loading to error when a kernel is not pre-downloaded, you can use
|
||||
the `load_kernel` function instead:
|
||||
|
||||
````python
|
||||
```python
|
||||
from hf_kernels import load_kernel
|
||||
|
||||
activation = load_kernel("kernels-community/activation")
|
||||
````
|
||||
|
51
docker/Dockerfile.reference
Normal file
51
docker/Dockerfile.reference
Normal file
@ -0,0 +1,51 @@
|
||||
FROM nvidia/cuda:12.4.0-devel-ubuntu22.04
|
||||
|
||||
# set environment vars
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH="/root/.local/bin:/root/.cargo/bin:${PATH}"
|
||||
|
||||
# install system deps
|
||||
RUN apt-get update && apt-get install -y \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
python3 \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# install git-lfs
|
||||
RUN git lfs install
|
||||
|
||||
# install uv
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# initialize uv and create virtual env
|
||||
RUN uv init --app kernel-test
|
||||
|
||||
# move into the app
|
||||
WORKDIR /app/kernel-test
|
||||
|
||||
# install python depdencies
|
||||
RUN uv add torch==2.5.0 numpy
|
||||
|
||||
# copy hf-kernels lib
|
||||
COPY src ./hf-kernels/src
|
||||
COPY pyproject.toml ./hf-kernels/pyproject.toml
|
||||
COPY README.md ./hf-kernels/README.md
|
||||
|
||||
# install library
|
||||
RUN uv pip install -e hf-kernels
|
||||
|
||||
# copy examples
|
||||
COPY examples ./examples
|
||||
|
||||
# set the nvidia runtime env
|
||||
ENV NVIDIA_VISIBLE_DEVICES=all
|
||||
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
|
||||
# command to run the script
|
||||
CMD ["uv", "run", "examples/basic.py"]
|
||||
# CMD ["ls", "hf-kernels"]
|
30
examples/basic.py
Normal file
30
examples/basic.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
|
||||
from hf_kernels import get_kernel
|
||||
|
||||
print("Starting examples/basic.py demo")
|
||||
|
||||
# Download optimized kernels from the Hugging Face hub
|
||||
activation = get_kernel("kernels-community/activation")
|
||||
|
||||
print("Activation kernel fetched")
|
||||
|
||||
# Create tensor
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device="cuda").view(3, 3)
|
||||
print("Input tensor created")
|
||||
|
||||
# Run the kernel
|
||||
y = torch.empty_like(x)
|
||||
activation.gelu_fast(y, x)
|
||||
|
||||
print("Kernel successfully executed")
|
||||
|
||||
# Check results
|
||||
expected = torch.tensor([
|
||||
[0.8408, 1.9551, 2.9961],
|
||||
[4.0000, 5.0000, 6.0000],
|
||||
[7.0000, 8.0000, 9.0000]
|
||||
], device='cuda:0', dtype=torch.float16)
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
print("Calculated values are exact")
|
@ -1,22 +1,40 @@
|
||||
[project]
|
||||
name = "hf-kernels"
|
||||
version = "0.1.0"
|
||||
version = "0.1.6"
|
||||
description = "Download cuda kernels"
|
||||
authors = [
|
||||
{name = "OlivierDehaene", email = "olivier@huggingface.co"},
|
||||
{name = "Daniel de Kok", email = "daniel@huggingface.co"},
|
||||
{name = "David Holtz", email = "david@huggingface.co"},
|
||||
{name = "Nicolas Patry", email = "nicolas@huggingface.co"}
|
||||
{ name = "OlivierDehaene", email = "olivier@huggingface.co" },
|
||||
{ name = "Daniel de Kok", email = "daniel@huggingface.co" },
|
||||
{ name = "David Holtz", email = "david@huggingface.co" },
|
||||
{ name = "Nicolas Patry", email = "nicolas@huggingface.co" },
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
python = "^3.9"
|
||||
huggingface-hub = "^0.26.3"
|
||||
packaging = "^24.2"
|
||||
tomli = { version = "^2.0.1", python = "<3.11" }
|
||||
requires-python = ">= 3.9"
|
||||
dependencies = [
|
||||
"huggingface-hub>=0.26.3",
|
||||
"packaging>=24.2",
|
||||
"tomli>=2.0.1; python_version<'3.11'",
|
||||
"torch>=2.4",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["torch", "huggingface_hub", "numpy"]
|
||||
build-backend = "hf_kernels.build"
|
||||
backend-path = ["src"]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest >=8",
|
||||
# Whatever version is compatible with pytest.
|
||||
"pytest-benchmark",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
hf-kernels = "hf_kernels.cli:main"
|
||||
|
||||
[project.entry-points."egg_info.writers"]
|
||||
"hf-kernels.lock" = "hf_kernels.lockfile:write_egg_lockfile"
|
||||
|
||||
#[build-system]
|
||||
#requires = ["torch", "huggingface_hub", "numpy", "tomli;python_version<='3.10'"]
|
||||
#build-backend = "hf_kernels.build"
|
||||
#backend-path = ["src"]
|
||||
|
@ -1,3 +1,3 @@
|
||||
from hf_kernels.utils import get_kernel, load_kernel, install_kernel
|
||||
from hf_kernels.utils import get_kernel, install_kernel, load_kernel, get_locked_kernel
|
||||
|
||||
__all__ = ["get_kernel", "load_kernel", "install_kernel"]
|
||||
__all__ = ["get_kernel", "get_locked_kernel", "load_kernel", "install_kernel"]
|
||||
|
@ -16,6 +16,8 @@ don't require importing typing but then quote them so earlier Python version ign
|
||||
them while IDEs and type checker can see through the quotes.
|
||||
"""
|
||||
|
||||
from hf_kernels.compat import tomllib
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence # noqa:I001
|
||||
@ -40,7 +42,6 @@ def call(
|
||||
warn_config_settings(config_settings)
|
||||
# Unlike `find_uv_bin`, this mechanism must work according to PEP 517
|
||||
import os
|
||||
import tomllib
|
||||
|
||||
cwd = os.getcwd()
|
||||
filename = os.path.join(cwd, "pyproject.toml")
|
||||
@ -48,7 +49,7 @@ def call(
|
||||
data = tomllib.load(f)
|
||||
|
||||
for kernel, _ in (
|
||||
data.get("tool", {}).get("kernels", {}).get("dependencies", {}).items()
|
||||
data.get("tool", {}).get("hf-kernels", {}).get("dependencies", {}).items()
|
||||
):
|
||||
from hf_kernels.utils import install_kernel
|
||||
|
||||
|
83
src/hf_kernels/cli.py
Normal file
83
src/hf_kernels/cli.py
Normal file
@ -0,0 +1,83 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from hf_kernels.compat import tomllib
|
||||
from hf_kernels.lockfile import KernelLock, get_kernel_locks
|
||||
from hf_kernels.utils import install_kernel, install_kernel_all_variants
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hf-kernel", description="Manage compute kernels"
|
||||
)
|
||||
subparsers = parser.add_subparsers(required=True)
|
||||
|
||||
download_parser = subparsers.add_parser("download", help="Download locked kernels")
|
||||
download_parser.add_argument(
|
||||
"project_dir",
|
||||
type=Path,
|
||||
help="The project directory",
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--all-variants",
|
||||
action="store_true",
|
||||
help="Download all build variants of the kernel",
|
||||
)
|
||||
download_parser.set_defaults(func=download_kernels)
|
||||
|
||||
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
|
||||
lock_parser.add_argument(
|
||||
"project_dir",
|
||||
type=Path,
|
||||
help="The project directory",
|
||||
)
|
||||
lock_parser.set_defaults(func=lock_kernels)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
||||
def download_kernels(args):
|
||||
lock_path = args.project_dir / "hf-kernels.lock"
|
||||
|
||||
if not lock_path.exists():
|
||||
print(f"No hf-kernels.lock file found in: {args.project_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
with open(args.project_dir / "hf-kernels.lock", "r") as f:
|
||||
lock_json = json.load(f)
|
||||
|
||||
for kernel_lock_json in lock_json:
|
||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||
print(
|
||||
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if args.all_variants:
|
||||
install_kernel_all_variants(kernel_lock.repo_id, kernel_lock.sha)
|
||||
else:
|
||||
install_kernel(kernel_lock.repo_id, kernel_lock.sha)
|
||||
|
||||
|
||||
def lock_kernels(args):
|
||||
with open(args.project_dir / "pyproject.toml", "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
||||
|
||||
all_locks = []
|
||||
for kernel, version in kernel_versions.items():
|
||||
all_locks.append(get_kernel_locks(kernel, version))
|
||||
|
||||
with open(args.project_dir / "hf-kernels.lock", "w") as f:
|
||||
json.dump(all_locks, f, cls=_JSONEncoder, indent=2)
|
||||
|
||||
|
||||
class _JSONEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if dataclasses.is_dataclass(o):
|
||||
return dataclasses.asdict(o)
|
||||
return super().default(o)
|
8
src/hf_kernels/compat.py
Normal file
8
src/hf_kernels/compat.py
Normal file
@ -0,0 +1,8 @@
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
import tomli as tomllib
|
||||
|
||||
__all__ = ["tomllib"]
|
112
src/hf_kernels/lockfile.py
Normal file
112
src/hf_kernels/lockfile.py
Normal file
@ -0,0 +1,112 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import InvalidVersion, Version
|
||||
|
||||
from hf_kernels.compat import tomllib
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileLock:
|
||||
filename: str
|
||||
blob_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelLock:
|
||||
repo_id: str
|
||||
sha: str
|
||||
files: List[FileLock]
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, o: Dict):
|
||||
files = [FileLock(**f) for f in o["files"]]
|
||||
return cls(repo_id=o["repo_id"], sha=o["sha"], files=files)
|
||||
|
||||
|
||||
def _get_available_versions(repo_id: str):
|
||||
"""Get kernel versions that are available in the repository."""
|
||||
versions = {}
|
||||
for tag in HfApi().list_repo_refs(repo_id).tags:
|
||||
if not tag.name.startswith("v"):
|
||||
continue
|
||||
try:
|
||||
versions[Version(tag.name[1:])] = tag
|
||||
except InvalidVersion:
|
||||
continue
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def get_kernel_locks(repo_id: str, version_spec: str):
|
||||
"""
|
||||
Get the locks for a kernel with the given version spec.
|
||||
|
||||
The version specifier can be any valid Python version specifier:
|
||||
https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers
|
||||
"""
|
||||
versions = _get_available_versions(repo_id)
|
||||
requirement = SpecifierSet(version_spec)
|
||||
accepted_versions = sorted(requirement.filter(versions.keys()))
|
||||
|
||||
if len(accepted_versions) == 0:
|
||||
raise ValueError(
|
||||
f"No version of `{repo_id}` satisfies requirement: {version_spec}"
|
||||
)
|
||||
|
||||
tag_for_newest = versions[accepted_versions[-1]]
|
||||
|
||||
r = HfApi().repo_info(
|
||||
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
|
||||
)
|
||||
if r.sha is None:
|
||||
raise ValueError(
|
||||
f"Cannot get commit SHA for repo {repo_id} for tag {tag_for_newest.name}"
|
||||
)
|
||||
|
||||
if r.siblings is None:
|
||||
raise ValueError(
|
||||
f"Cannot get sibling information for {repo_id} for tag {tag_for_newest.name}"
|
||||
)
|
||||
|
||||
file_locks = []
|
||||
for sibling in r.siblings:
|
||||
if sibling.rfilename.startswith("build/torch"):
|
||||
if sibling.blob_id is None:
|
||||
raise ValueError(f"Cannot get blob ID for {sibling.rfilename}")
|
||||
|
||||
file_locks.append(
|
||||
FileLock(filename=sibling.rfilename, blob_id=sibling.blob_id)
|
||||
)
|
||||
|
||||
return KernelLock(repo_id=repo_id, sha=r.sha, files=file_locks)
|
||||
|
||||
|
||||
def write_egg_lockfile(cmd, basename, filename):
|
||||
import logging
|
||||
|
||||
cwd = Path.cwd()
|
||||
pyproject_path = cwd / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
# Nothing to do if the project doesn't have pyproject.toml.
|
||||
return
|
||||
|
||||
with open(pyproject_path, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
|
||||
kernel_versions = data.get("tool", {}).get("kernels", {}).get("dependencies", None)
|
||||
if kernel_versions is None:
|
||||
return
|
||||
|
||||
lock_path = cwd / "hf-kernels.lock"
|
||||
if not lock_path.exists():
|
||||
logging.warning(f"Lock file {lock_path} does not exist")
|
||||
# Ensure that the file gets deleted in editable installs.
|
||||
data = None
|
||||
else:
|
||||
data = open(lock_path, "r").read()
|
||||
|
||||
cmd.write_or_delete_file(basename, filename, data)
|
@ -1,19 +1,27 @@
|
||||
import ctypes
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import os
|
||||
from importlib.metadata import Distribution
|
||||
from types import ModuleType
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from packaging.version import parse
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
import tomli as tomllib
|
||||
from hf_kernels.compat import tomllib
|
||||
from hf_kernels.lockfile import KernelLock
|
||||
|
||||
CACHE_DIR: Optional[str] = os.environ.get("HF_KERNELS_CACHE", None)
|
||||
|
||||
|
||||
def build_variant():
|
||||
import torch
|
||||
|
||||
torch_version = parse(torch.__version__)
|
||||
cuda_version = parse(torch.version.cuda)
|
||||
cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"
|
||||
@ -24,6 +32,12 @@ def build_variant():
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path):
|
||||
# We cannot use the module name as-is, after adding it to `sys.modules`,
|
||||
# it would also be used for other imports. So, we make a module name that
|
||||
# depends on the path for it to be unique using the hex-encoded hash of
|
||||
# the path.
|
||||
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
|
||||
module_name = f"{module_name}_{path_hash}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
@ -31,16 +45,43 @@ def import_from_path(module_name: str, file_path):
|
||||
return module
|
||||
|
||||
|
||||
def install_kernel(repo_id: str, revision: str):
|
||||
package_name = get_metadata(repo_id)["torch"]["name"]
|
||||
def install_kernel(repo_id: str, revision: str, local_files_only: bool = False):
|
||||
package_name = get_metadata(repo_id, revision, local_files_only=local_files_only)[
|
||||
"torch"
|
||||
]["name"]
|
||||
repo_path = snapshot_download(
|
||||
repo_id, allow_patterns=f"build/{build_variant()}/*", revision=revision
|
||||
repo_id,
|
||||
allow_patterns=f"build/{build_variant()}/*",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return package_name, f"{repo_path}/build/{build_variant()}"
|
||||
|
||||
|
||||
def get_metadata(repo_id: str):
|
||||
with open(hf_hub_download(repo_id, "build.toml"), "rb") as f:
|
||||
def install_kernel_all_variants(
|
||||
repo_id: str, revision: str, local_files_only: bool = False
|
||||
):
|
||||
snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns="build/*",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
|
||||
def get_metadata(repo_id: str, revision: str, local_files_only: bool = False):
|
||||
with open(
|
||||
hf_hub_download(
|
||||
repo_id,
|
||||
"build.toml",
|
||||
cache_dir=CACHE_DIR,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
),
|
||||
"rb",
|
||||
) as f:
|
||||
return tomllib.load(f)
|
||||
|
||||
|
||||
@ -49,13 +90,74 @@ def get_kernel(repo_id: str, revision: str = "main"):
|
||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
||||
|
||||
|
||||
def load_kernel(repo_id: str, revision: str = "main"):
|
||||
def load_kernel(repo_id: str):
|
||||
"""Get a pre-downloaded, locked kernel."""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
if locked_sha is None:
|
||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
||||
|
||||
filename = hf_hub_download(
|
||||
repo_id, "build.toml", local_files_only=True, revision=revision
|
||||
repo_id,
|
||||
"build.toml",
|
||||
cache_dir=CACHE_DIR,
|
||||
local_files_only=True,
|
||||
revision=locked_sha,
|
||||
)
|
||||
with open(filename, "rb") as f:
|
||||
metadata = tomllib.load(f)
|
||||
package_name = metadata["torch"]["name"]
|
||||
|
||||
repo_path = os.path.dirname(filename)
|
||||
package_path = f"{repo_path}/build/{build_variant()}"
|
||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
||||
|
||||
|
||||
def get_locked_kernel(repo_id: str, local_files_only: bool = False):
|
||||
"""Get a kernel using a lock file."""
|
||||
locked_sha = _get_caller_locked_kernel(repo_id)
|
||||
|
||||
if locked_sha is None:
|
||||
raise ValueError(f"Kernel `{repo_id}` is not locked")
|
||||
|
||||
package_name, package_path = install_kernel(
|
||||
repo_id, locked_sha, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
return import_from_path(package_name, f"{package_path}/{package_name}/__init__.py")
|
||||
|
||||
|
||||
def _get_caller_locked_kernel(repo_id: str) -> Optional[str]:
|
||||
for dist in _get_caller_distributions():
|
||||
lock_json = dist.read_text("hf-kernels.lock")
|
||||
if lock_json is not None:
|
||||
for kernel_lock_json in json.loads(lock_json):
|
||||
kernel_lock = KernelLock.from_json(kernel_lock_json)
|
||||
if kernel_lock.repo_id == repo_id:
|
||||
return kernel_lock.sha
|
||||
return None
|
||||
|
||||
|
||||
def _get_caller_distributions() -> List[Distribution]:
|
||||
module = _get_caller_module()
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
# Look up all possible distributions that this module could be from.
|
||||
package = module.__name__.split(".")[0]
|
||||
dist_names = importlib.metadata.packages_distributions().get(package)
|
||||
if dist_names is None:
|
||||
return []
|
||||
|
||||
return [importlib.metadata.distribution(dist_name) for dist_name in dist_names]
|
||||
|
||||
|
||||
def _get_caller_module() -> Optional[ModuleType]:
|
||||
stack = inspect.stack()
|
||||
# Get first module in the stack that is not the current module.
|
||||
first_module = inspect.getmodule(stack[0][0])
|
||||
for frame in stack[1:]:
|
||||
module = inspect.getmodule(frame[0])
|
||||
if module is not None and module != first_module:
|
||||
return module
|
||||
return first_module
|
||||
|
30
tests/test_basic.py
Normal file
30
tests/test_basic.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_kernels import get_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kernel():
|
||||
return get_kernel("kernels-community/activation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("No CUDA")
|
||||
return "cuda"
|
||||
|
||||
|
||||
def test_gelu_fast(kernel, device):
|
||||
x = torch.arange(1, 10, dtype=torch.float16, device=device).view(3, 3)
|
||||
y = torch.empty_like(x)
|
||||
|
||||
kernel.gelu_fast(y, x)
|
||||
|
||||
expected = torch.tensor(
|
||||
[[0.8408, 1.9551, 2.9961], [4.0000, 5.0000, 6.0000], [7.0000, 8.0000, 9.0000]],
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert torch.allclose(y, expected)
|
33
tests/test_benchmarks.py
Normal file
33
tests/test_benchmarks.py
Normal file
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_kernels import get_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kernel():
|
||||
return get_kernel("kernels-community/activation")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("No CUDA")
|
||||
return "cuda"
|
||||
|
||||
|
||||
def test_gelu_small(kernel, device, benchmark):
|
||||
x = torch.randn(32, 32, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
benchmark(kernel.gelu_fast, y, x)
|
||||
|
||||
|
||||
def test_gelu_medium(kernel, device, benchmark):
|
||||
x = torch.randn(128, 128, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
benchmark(kernel.gelu_fast, y, x)
|
||||
|
||||
|
||||
def test_gelu_large(kernel, device, benchmark):
|
||||
x = torch.randn(512, 512, dtype=torch.float16, device=device)
|
||||
y = torch.empty_like(x)
|
||||
benchmark(kernel.gelu_fast, y, x)
|
Reference in New Issue
Block a user