diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e8c206e..d194d85 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -51,7 +51,10 @@ jobs: run: uv run mypy src/kernels - name: Run tests - run: uv run pytest tests + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + uv run pytest tests - name: Check kernel conversion run: | diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 6067331..de0154f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -21,6 +21,8 @@ title: Kernels - local: api/layers title: Layers + - local: cli + title: kernels CLI title: API Reference - sections: - local: kernel-requirements diff --git a/docs/source/cli.md b/docs/source/cli.md new file mode 100644 index 0000000..323fe32 --- /dev/null +++ b/docs/source/cli.md @@ -0,0 +1,15 @@ +# Kernels CLI Reference + +## Main Functions + +### kernels upload + +Use `kernels upload --repo_id="hub-username/kernel"` to upload +your kernel builds to the Hub. + +**Notes**: + +* This will take care of creating a repository on the Hub with the `repo_id` provided. +* If a repo with the `repo_id` already exists and if it contains a `build` with the build variant +being uploaded, it will attempt to delete the files existing under it. +* Make sure to be authenticated (run `hf auth login` if not) to be able to perform uploads to the Hub. \ No newline at end of file diff --git a/src/kernels/cli.py b/src/kernels/cli.py index 50e42e1..90f5db5 100644 --- a/src/kernels/cli.py +++ b/src/kernels/cli.py @@ -4,6 +4,8 @@ import json import sys from pathlib import Path +from huggingface_hub import create_repo, upload_folder + from kernels.compat import tomllib from kernels.lockfile import KernelLock, get_kernel_locks from kernels.utils import install_kernel, install_kernel_all_variants @@ -31,6 +33,24 @@ def main(): ) download_parser.set_defaults(func=download_kernels) + upload_parser = subparsers.add_parser("upload", help="Upload kernels to the Hub") + upload_parser.add_argument( + "kernel_dir", + type=Path, + help="Directory of the kernel build", + ) + upload_parser.add_argument( + "--repo_id", + type=str, + help="Repository ID to use to upload to the Hugging Face Hub", + ) + upload_parser.add_argument( + "--private", + action="store_true", + help="If the repository should be private.", + ) + upload_parser.set_defaults(func=upload_kernels) + lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") lock_parser.add_argument( "project_dir", @@ -153,6 +173,33 @@ def lock_kernels(args): json.dump(all_locks, f, cls=_JSONEncoder, indent=2) +def upload_kernels(args): + kernel_dir = Path(args.kernel_dir).resolve() + build_dir = kernel_dir / "build" + if not kernel_dir.is_dir(): + raise ValueError(f"{kernel_dir} is not a directory") + if not build_dir.is_dir(): + raise ValueError("Couldn't find `build` directory inside `kernel_dir`") + + repo_id = create_repo( + repo_id=args.repo_id, private=args.private, exist_ok=True + ).repo_id + + delete_patterns: set[str] = set() + for build_variant in build_dir.iterdir(): + if build_variant.is_dir(): + delete_patterns.add(f"{build_variant.name}/**") + + upload_folder( + repo_id=repo_id, + folder_path=build_dir, + path_in_repo="build", + delete_patterns=list(delete_patterns), + commit_message="Build uploaded using `kernels`.", + ) + print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.") + + class _JSONEncoder(json.JSONEncoder): def default(self, o): if dataclasses.is_dataclass(o): diff --git a/tests/test_kernel_upload.py b/tests/test_kernel_upload.py new file mode 100644 index 0000000..dbe20ac --- /dev/null +++ b/tests/test_kernel_upload.py @@ -0,0 +1,92 @@ +import logging +import os +import re +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import pytest +from huggingface_hub import model_info + +from kernels.cli import upload_kernels + +REPO_ID = "kernels-test/kernels-upload-test" + +PY_CONTENT = """\ +#!/usr/bin/env python3 + +def main(): + print("Hello from torch-universal!") + +if __name__ == "__main__": + main() +""" + + +@dataclass +class UploadArgs: + kernel_dir: None + repo_id: None + private: False + + +def next_filename(path: Path) -> Path: + """ + Given a path like foo_2050.py, return foo_2051.py. + """ + m = re.match(r"^(.*?)(\d+)(\.py)$", path.name) + if not m: + raise ValueError( + f"Filename {path.name!r} does not match pattern _.py" + ) + + prefix, number, suffix = m.groups() + new_number = str(int(number) + 1).zfill(len(number)) + return path.with_name(f"{prefix}{new_number}{suffix}") + + +def get_filename_to_change(repo_filenames): + for f in repo_filenames: + if "foo" in f and f.endswith(".py"): + filename_to_change = os.path.basename(f) + break + assert filename_to_change + return filename_to_change + + +def get_filenames_from_a_repo(repo_id: str) -> List[str]: + try: + repo_info = model_info(repo_id=repo_id, files_metadata=True) + repo_siblings = repo_info.siblings + if repo_siblings is not None: + return [f.rfilename for f in repo_siblings] + else: + raise ValueError("No repo siblings found.") + except Exception as e: + logging.error(f"Error connecting to the Hub: {e}.") + + +@pytest.mark.xfail( + condition=os.environ.get("GITHUB_ACTIONS") == "true", + reason="There is something weird when writing to the Hub from a GitHub CI.", + strict=True, +) +def test_kernel_upload_deletes_as_expected(): + repo_filenames = get_filenames_from_a_repo(REPO_ID) + filename_to_change = get_filename_to_change(repo_filenames) + + with tempfile.TemporaryDirectory() as tmpdir: + path = f"{tmpdir}/build/torch-universal/upload_test" + build_dir = Path(path) + build_dir.mkdir(parents=True, exist_ok=True) + changed_filename = next_filename(Path(filename_to_change)) + script_path = build_dir / changed_filename + script_path.write_text(PY_CONTENT) + upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) + + repo_filenames = get_filenames_from_a_repo(REPO_ID) + assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}" + assert not any( + str(filename_to_change) in k for k in repo_filenames + ), f"{repo_filenames=}"