From ce77658efcdf57282b2a9600f0ada519015acbcd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 16 Oct 2025 19:31:00 +0530 Subject: [PATCH] fix: kernels upload to a repo branch (#168) * fix: kernels upload to a repo branch * up --- src/kernels/cli.py | 6 +++++- tests/test_kernel_upload.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/kernels/cli.py b/src/kernels/cli.py index 4ea55a5..b10e3b2 100644 --- a/src/kernels/cli.py +++ b/src/kernels/cli.py @@ -4,7 +4,7 @@ import json import sys from pathlib import Path -from huggingface_hub import create_repo, upload_folder +from huggingface_hub import create_repo, upload_folder, create_branch from kernels.compat import tomllib from kernels.lockfile import KernelLock, get_kernel_locks @@ -204,6 +204,7 @@ def lock_kernels(args): def upload_kernels(args): + # Resolve `kernel_dir` to be uploaded. kernel_dir = Path(args.kernel_dir).resolve() build_dir = kernel_dir / "build" if not kernel_dir.is_dir(): @@ -215,6 +216,9 @@ def upload_kernels(args): repo_id=args.repo_id, private=args.private, exist_ok=True ).repo_id + if args.branch is not None: + create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True) + delete_patterns: set[str] = set() for build_variant in build_dir.iterdir(): if build_variant.is_dir(): diff --git a/tests/test_kernel_upload.py b/tests/test_kernel_upload.py index d962c92..f755239 100644 --- a/tests/test_kernel_upload.py +++ b/tests/test_kernel_upload.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import List import pytest -from huggingface_hub import delete_repo, model_info +from huggingface_hub import delete_repo, model_info, list_repo_refs from kernels.cli import upload_kernels @@ -83,6 +83,11 @@ def test_kernel_upload_works_as_expected(branch): repo_filenames = get_filenames_from_a_repo(REPO_ID) assert any(str(script_path.name) for f in repo_filenames) + + if branch is not None: + refs = list_repo_refs(repo_id=REPO_ID) + assert any(ref_branch.name == branch for ref_branch in refs.branches) + delete_repo(repo_id=REPO_ID) @@ -95,7 +100,7 @@ def test_kernel_upload_deletes_as_expected(): build_dir.mkdir(parents=True, exist_ok=True) script_path = build_dir / "foo_2025.py" script_path.write_text(PY_CONTENT) - upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) + upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None)) repo_filenames = get_filenames_from_a_repo(REPO_ID) filename_to_change = get_filename_to_change(repo_filenames) @@ -107,7 +112,7 @@ def test_kernel_upload_deletes_as_expected(): changed_filename = next_filename(Path(filename_to_change)) script_path = build_dir / changed_filename script_path.write_text(PY_CONTENT) - upload_kernels(UploadArgs(tmpdir, REPO_ID, False)) + upload_kernels(UploadArgs(tmpdir, REPO_ID, False, None)) repo_filenames = get_filenames_from_a_repo(REPO_ID) assert any(str(changed_filename) in k for k in repo_filenames), f"{repo_filenames=}"